In [9]:
import sys
import os
import requests
import tqdm
import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image
from skimage.measure import compare_psnr
# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/facebookresearch/mae.git
    sys.path.append('./mae')
else:
    sys.path.append('..')
import models_mae


In [10]:

def run_one_image_average(img,img_ori, model, iterations=50):
    # ImageNet mean and std used for normalization
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225])

    model.eval()
    model = model.to('cuda')
    
    # Prepare the original image for visualization
    original_img = torch.tensor(img)
    original_img = torch.clip((original_img * imagenet_std + imagenet_mean) * 255, 0, 255).int()

    # Initialize a tensor to accumulate the reconstructions on GPU
    # 确保accumulated_recons是浮点型
    accumulated_recons = torch.zeros_like(torch.tensor(img).unsqueeze(dim=0)).to('cuda').float()

    for _ in range(iterations):
        x = torch.tensor(img).to('cuda')
        x = x.unsqueeze(dim=0)
        x = torch.einsum('nhwc->nchw', x)

        with torch.no_grad():
            _, y, _ = model(x.float(), mask_ratio=0.5)
            y = model.unpatchify(y)
            y = torch.einsum('nchw->nhwc', y)

            # 累加浮点型重建结果
            accumulated_recons += y.float()

    # Calculate the average reconstruction and move it back to CPU
    avg_reconstruction = (accumulated_recons / iterations).detach().cpu()

    # 在可视化之前将平均重建结果转换为整型
    avg_reconstruction = torch.clip((avg_reconstruction * imagenet_std + imagenet_mean) * 255, 0, 255).int()


    # Set up the plot
    plt.rcParams['figure.figsize'] = [24, 24]

    # Plotting the original image
    plt.subplot(1, 2, 1)
    plt.imshow(original_img.numpy())
    plt.title("Original Image", fontsize=16)
    plt.axis('off')

    # Plotting the average reconstruction image
    plt.subplot(1, 2, 2)
    plt.imshow(avg_reconstruction[0].numpy())
    plt.title("Average Reconstruction", fontsize=16)
    plt.axis('off')

    # Display the plot
    plt.show()
    print(np.shape(avg_reconstruction[0].numpy()))
    print(np.shape(img))
    #print(avg_reconstruction[0].numpy()/255.)
    print(img)
    original_img_ori = torch.tensor(img_ori)
    original_img_ori = torch.clip((original_img_ori * imagenet_std + imagenet_mean) * 255, 0, 255).int()
    print(np.shape(original_img))
    psrn_noisy = compare_psnr(avg_reconstruction[0].numpy()/255.,original_img.numpy()/255.)
    psnr_gt = compare_psnr(avg_reconstruction[0].numpy()/255.,original_img_ori.numpy()/255.)
    print ('PSNR_noisy: %f   PSRN_gt: %f ' % (psrn_noisy, psnr_gt), '\r', end='')
    #ssim = 
    #print(psnr)
    #print("SSIM="+ssim)

# define the utils

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')
    return

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    # build model
    model = getattr(models_mae, arch)()
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model

def run_one_image(img, model):
    x = torch.tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE
    loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    
    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()



In [11]:
def MAE_Self2Self_like(img, model, iterations=50):
    # ImageNet mean and std used for normalization
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225])

    model.eval()
    model = model.to('cuda')
    
    # Prepare the original image for visualization
    original_img = torch.tensor(img)
    original_img = torch.clip((original_img * imagenet_std + imagenet_mean) * 255, 0, 255).int()

    # Initialize a tensor to accumulate the reconstructions on GPU
    # 确保accumulated_recons是浮点型
    accumulated_recons = torch.zeros_like(torch.tensor(img).unsqueeze(dim=0)).to('cuda').float()

    for _ in range(iterations):
        x = torch.tensor(img).to('cuda')
        x = x.unsqueeze(dim=0)
        x = torch.einsum('nhwc->nchw', x)

        with torch.no_grad():
            _, y, _ = model(x.float(), mask_ratio=0.5)
            y = model.unpatchify(y)
            y = torch.einsum('nchw->nhwc', y)

            # 累加浮点型重建结果
            accumulated_recons += y.float()

    # Calculate the average reconstruction and move it back to CPU
    avg_reconstruction = (accumulated_recons / iterations).detach().cpu()

    # 在可视化之前将平均重建结果转换为整型
    avg_reconstruction = torch.clip((avg_reconstruction * imagenet_std + imagenet_mean) * 255, 0, 255).int()
    '''
    # Set up the plot
    plt.rcParams['figure.figsize'] = [24, 24]

    # Plotting the original image
    plt.subplot(1, 2, 1)
    plt.imshow(original_img.numpy())
    plt.title("Original Image", fontsize=16)
    plt.axis('off')

    # Plotting the average reconstruction image
    plt.subplot(1, 2, 2)
    plt.imshow(avg_reconstruction[0].numpy())
    plt.title("Average Reconstruction", fontsize=16)
    plt.axis('off')

    # Display the plot
    plt.show()
    print(np.shape(avg_reconstruction[0].numpy()))
    print(np.shape(img))
    #print(avg_reconstruction[0].numpy()/255.)
    print(img)
    original_img_ori = torch.tensor(img_ori)
    original_img_ori = torch.clip((original_img_ori * imagenet_std + imagenet_mean) * 255, 0, 255).int()
    print(np.shape(original_img))
    psrn_noisy = compare_psnr(avg_reconstruction[0].numpy()/255.,original_img.numpy()/255.)
    psnr_gt = compare_psnr(avg_reconstruction[0].numpy()/255.,original_img_ori.numpy()/255.)
    print ('PSNR_noisy: %f   PSRN_gt: %f ' % (psrn_noisy, psnr_gt), '\r', end='')
    #ssim = 
    #print(psnr)
    #print("SSIM="+ssim)
    '''
    output_img = avg_reconstruction[0].numpy()
    return output_img



In [12]:
# download checkpoint if not exist
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth

chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')

File ‘mae_visualize_vit_large.pth’ already there; not retrieving.

<All keys matched successfully>
Model loaded.


In [13]:

from PIL import Image
import os

def process_image_with_model(input_image_path, output_image_path, model, iterations=50):
    img = Image.open(input_image_path)
    img = np.array(img) / 255.
    assert img.shape == (224, 224, 3)
    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std
    processed_img = MAE_Self2Self_like(img, model, iterations)
    # 将数组的数据类型转换为 np.uint8
    processed_img = processed_img.astype(np.uint8)
    processed_img_pil = Image.fromarray(processed_img)
    processed_img_pil.save(output_image_path)


def batch_process_images_with_model(input_folder, output_folder, model, iterations=50):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    for image_name in os.listdir(input_folder):
        if image_name.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
            input_image_path = os.path.join(input_folder, image_name)
            output_image_path = os.path.join(output_folder, image_name)
            # 检查输出文件夹中是否已存在该图片
            if not os.path.exists(output_image_path):
                process_image_with_model(input_image_path, output_image_path, model, iterations)
            #process_image_with_model(input_image_path, output_image_path, model, iterations)
            #print(f'Processed image {image_name}')
'''

def batch_process_images_with_model(input_folder, output_folder, model, iterations=50):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # 获取所有有效图片文件
    image_files = [f for f in os.listdir(input_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]

    # 使用tqdm创建进度条
    for image_name in tqdm(image_files, desc="Processing images"):
        input_image_path = os.path.join(input_folder, image_name)
        output_image_path = os.path.join(output_folder, image_name)
        process_image_with_model(input_image_path, output_image_path, model, iterations)
        # 已移除打印文件名，进度条将显示进度
'''''' 
def batch_process_images_with_model(input_folder, output_folder, model, iterations=50):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # 使用tqdm创建进度条
    for image_name in tqdm(os.listdir(input_folder), desc="Processing images"):
        if image_name.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
            input_image_path = os.path.join(input_folder, image_name)
            output_image_path = os.path.join(output_folder, image_name)
            process_image_with_model(input_image_path, output_image_path, model, iterations)
            print(f'Processed image {image_name}')
        # 已移除打印文件名，进度条将显示进度
''' 


'\n\ndef batch_process_images_with_model(input_folder, output_folder, model, iterations=50):\n    if not os.path.exists(output_folder):\n        os.makedirs(output_folder)\n\n    # 获取所有有效图片文件\n    image_files = [f for f in os.listdir(input_folder) if f.lower().endswith((\'.png\', \'.jpg\', \'.jpeg\', \'.gif\', \'.bmp\'))]\n\n    # 使用tqdm创建进度条\n    for image_name in tqdm(image_files, desc="Processing images"):\n        input_image_path = os.path.join(input_folder, image_name)\n        output_image_path = os.path.join(output_folder, image_name)\n        process_image_with_model(input_image_path, output_image_path, model, iterations)\n        # 已移除打印文件名，进度条将显示进度\n \ndef batch_process_images_with_model(input_folder, output_folder, model, iterations=50):\n    if not os.path.exists(output_folder):\n        os.makedirs(output_folder)\n\n    # 使用tqdm创建进度条\n    for image_name in tqdm(os.listdir(input_folder), desc="Processing images"):\n        if image_name.lower().endswith((\'.png\', \'.jpg\'

In [19]:

# 示例使用
input_folder = 'CBSD68-224-Poisson-lam/lam40'  # 修改为您的输入文件夹路径
output_folder = 'CBSD68-224-Poisson-lam-mae/lam40'  # 修改为您的输出文件夹路径
# This is an MAE model trained with an extra GAN loss for more realistic generation (ViT-Large, training mask ratio=0.75)

# download checkpoint if not exist
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth

chkpt_dir = 'mae_visualize_vit_large_ganloss.pth'
model_mae_gan = prepare_model('mae_visualize_vit_large_ganloss.pth', 'mae_vit_large_patch16')
print('Model loaded.')

batch_process_images_with_model(input_folder, output_folder, model_mae_gan, iterations=50)


File ‘mae_visualize_vit_large_ganloss.pth’ already there; not retrieving.

<All keys matched successfully>
Model loaded.
