In [1]:
import os
import torch
import numpy as np
from PIL import Image
from pytorch_msssim import ms_ssim, ssim
import matplotlib.pyplot as plt
import pandas as pd
from util.metric import MultiScaleSSIM, psnr

In [2]:
# Parameters: dataset folders, trained checkpoint and model variant.
png_folder = 'data/kodim/png'   # original PNG images
rnn_folder = 'data/kodim/rnn'   # .npz code files, one per image and iteration
model_path = 'checkpoint/tiny-imagenet-200-ConvGRUCell/batch32-lr0.0005-l1-06_18_13_46/_best_model_epoch_0192.pth'
reconstruction_metohod = 'one_shot'
rnn_model = 'ConvGRUCell'       # any other value selects the LSTM network

# Device preference: CUDA first, then Apple MPS, otherwise CPU.
# The MPS check is only evaluated when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [3]:
# Fixed setup: file listings, result containers and the pretrained decoder.
png_images_path = sorted(os.listdir(png_folder))
rnn_images_path = sorted(os.listdir(rnn_folder))

iters = range(1, 33)  # iteration numbers 1..32
results_dict = {sampling: [] for sampling in iters}
codes_dict = {}            # filled by the loading cell: image key -> list of [shape, codes]
decoded_images_dict = {}   # filled by the decoding cell: image key -> list of [key, iter, image]

# The network module is chosen by rnn_model; both are expected to expose DecoderCell.
if rnn_model == 'ConvGRUCell':
    from modules import GRU_network as network
else:
    from modules import LSTM_network as network

print('','='*120,'\n ||\tdevice: {}\n'.format(device), "="*120, "\n")
decoder = network.DecoderCell().to(device)
decoder.eval()

# map_location remaps the stored tensors onto the selected device, so a checkpoint
# saved on a CUDA machine can still be loaded when running on CPU or MPS.
checkpoint = torch.load(model_path, map_location=device)
decoder.load_state_dict(checkpoint['decoder'])

 ||	device: cuda



<All keys matched successfully>

In [4]:
# Load the original images and the stored codes.
png_images = {}
for png_name in png_images_path:
    if not png_name.lower().endswith('.png'):
        continue
    image_key = png_name.split('.')[0]  # e.g. 'kodim01.png' -> 'kodim01'
    # Start with empty per-image containers for codes and reconstructions.
    codes_dict[image_key] = []
    decoded_images_dict[image_key] = []
    with Image.open(os.path.join(png_folder, png_name)) as img:
        # Move the last axis to the front (H, W, C -> C, H, W for a colour image).
        png_images[image_key] = np.array(img).transpose(2, 0, 1)

for npz_name in rnn_images_path:
    if not npz_name.lower().endswith('.npz'):
        continue
    with np.load(os.path.join(rnn_folder, npz_name)) as data:
        # The part of the file name before the first '_' is the image key.
        codes_dict[npz_name.split('_')[0]].append([data['shape'], data['codes']])
        # Author's note (translated): 32*32*48 / 6144 = 8 -> i.e. the code is 8-bit;
        # since it is 32, 32 this is actually 4, 4 patches.

In [5]:
# Decode every image from its stored codes, keeping the reconstruction after each iteration.
for k in codes_dict.keys():  # k is an image key such as 'kodim01'
    # content[i] is [shape, codes] for the i-th stored .npz file of this image.
    content = codes_dict[k]
    print("[decode process] image: ", k)
    # Only the last stored entry is decoded; presumably it holds the codes of
    # every iteration (its leading dimension is what the loop below walks over).
    codes = np.unpackbits(content[-1][1])
    # Unpacked bits are 0/1; map them to -1/+1 after restoring the stored shape.
    codes = np.reshape(codes, content[-1][0]).astype(np.float32) * 2 - 1

    codes = torch.from_numpy(codes)
    # A dedicated name avoids clobbering the notebook-level `iters` range.
    num_iters, batch_size, _channels, height, width = codes.size()
    # The image is 16x the spatial size of the codes.
    height = height * 16
    width = width * 16

    # Start from a clean list so re-running this cell does not duplicate entries.
    decoded_images_dict[k] = []

    def _zero_state(num_channels, scale):
        """Return a zero hidden-state tensor at 1/scale of the image resolution."""
        return torch.zeros(batch_size, num_channels, height // scale, width // scale, device=device)

    with torch.no_grad():
        state_specs = [(512, 16), (512, 8), (256, 4), (128, 2)]
        if rnn_model == 'ConvGRUCell':
            # GRU: one hidden tensor per decoder stage.
            states = [_zero_state(c, s) for c, s in state_specs]
        else:
            # LSTM: a pair of tensors per decoder stage. Each tensor is created on
            # the device directly, because a tuple has no `.to()` method.
            states = [(_zero_state(c, s), _zero_state(c, s)) for c, s in state_specs]
        decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = states

        codes = codes.to(device)
        # The reconstruction is accumulated on the CPU, starting from mid-gray (0.5).
        image = torch.zeros(1, 3, height, width) + 0.5

        for step in range(num_iters):
            output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
                codes[step], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
            # Each decoder output is added onto the running reconstruction.
            image = image + output.detach().cpu()
            # Clip to [0, 1], scale to 8-bit and convert C, H, W -> H, W, C.
            image_disp = np.squeeze(image[0].numpy().clip(0, 1) * 255.0).astype(np.uint8).transpose(1, 2, 0)
            decoded_images_dict[k].append([k, step + 1, image_disp])

[decode process] image:  kodim01
[decode process] image:  kodim02
[decode process] image:  kodim03
[decode process] image:  kodim04
[decode process] image:  kodim05
[decode process] image:  kodim06
[decode process] image:  kodim07
[decode process] image:  kodim08
[decode process] image:  kodim09
[decode process] image:  kodim10
[decode process] image:  kodim11
[decode process] image:  kodim12
[decode process] image:  kodim13
[decode process] image:  kodim14
[decode process] image:  kodim15
[decode process] image:  kodim16
[decode process] image:  kodim17
[decode process] image:  kodim18
[decode process] image:  kodim19
[decode process] image:  kodim20
[decode process] image:  kodim21
[decode process] image:  kodim22
[decode process] image:  kodim23
[decode process] image:  kodim24


In [6]:
# png_images holds the original images (kodim01 ... kodim24) and decoded_images_dict
# holds, per image key, one [key, iter, image] entry per decoding iteration.
# results_dict maps a 1-based iteration number to one metrics row per image:
# [iter, bpp, SSIM, MS-SSIM1, MS-SSIM2, PSNR].
results_dict = {}

for k in png_images.keys():
    org_image = np.expand_dims(png_images[k], axis=0)
    # Convert the reference once per image instead of once per metric call.
    org_tensor = torch.tensor(org_image).float()
    print("[calculate metric] image: ", k)
    for i in range(len(decoded_images_dict[k])):
        iteration = i + 1  # 1-based, matching the .npz file names and the printout
        decoded_images = decoded_images_dict[k][i][2].transpose(2, 0, 1)
        decoded_images = np.expand_dims(decoded_images, axis=0)
        decoded_tensor = torch.tensor(decoded_images).float()

        # NOTE(review): decoded_images is uint8 (org_image presumably too) and both are
        # channel-first (N, C, H, W). util.metric is not visible here -- confirm that
        # MultiScaleSSIM and psnr cast to float before doing arithmetic and that they
        # expect this axis layout; otherwise these two scores are unreliable.
        ms_ssim_score_1 = MultiScaleSSIM(org_image, decoded_images, max_val=255)
        ms_ssim_score_2 = ms_ssim(org_tensor, decoded_tensor, data_range=255, size_average=True).item()
        ssim_score = ssim(org_tensor, decoded_tensor, data_range=255, size_average=True).item()
        psnr_score = psnr(org_image, decoded_images)

        # Bits per pixel: size of this iteration's .npz file in bits over the pixel count.
        path = os.path.join(rnn_folder, f"{k}_iter{iteration:02}.npz")
        rnn_size = os.path.getsize(path) * 8
        bpp = rnn_size / (decoded_images.shape[2] * decoded_images.shape[3])

        # Buckets are created on demand, so any number of iterations is supported and
        # the stored iteration number agrees with the dictionary key.
        results_dict.setdefault(iteration, []).append(
            [iteration, bpp, ssim_score, ms_ssim_score_1, ms_ssim_score_2, psnr_score])
        print(f"iter: {iteration:02} | bpp: {bpp:.4f} | SSIM: {ssim_score:.4f} | MS-SSIM1: {ms_ssim_score_1:.4f} | MS-SSIM2: {ms_ssim_score_2:.4f} |PSNR: {psnr_score:.4f}")

[calculate metric] image:  kodim01
iter: 01 | bpp: 0.1200 | SSIM: 0.5003 | MS-SSIM1: 0.9357 | MS-SSIM2: 0.7875 |PSNR: 29.0016
iter: 02 | bpp: 0.2446 | SSIM: 0.6432 | MS-SSIM1: 0.9610 | MS-SSIM2: 0.9015 |PSNR: 29.9107
iter: 03 | bpp: 0.3705 | SSIM: 0.7151 | MS-SSIM1: 0.9682 | MS-SSIM2: 0.9345 |PSNR: 30.4169
iter: 04 | bpp: 0.4955 | SSIM: 0.7612 | MS-SSIM1: 0.9739 | MS-SSIM2: 0.9513 |PSNR: 30.7462
iter: 05 | bpp: 0.6205 | SSIM: 0.8006 | MS-SSIM1: 0.9770 | MS-SSIM2: 0.9602 |PSNR: 31.0786
iter: 06 | bpp: 0.7456 | SSIM: 0.8309 | MS-SSIM1: 0.9800 | MS-SSIM2: 0.9678 |PSNR: 31.3792
iter: 07 | bpp: 0.8706 | SSIM: 0.8544 | MS-SSIM1: 0.9826 | MS-SSIM2: 0.9727 |PSNR: 31.6880
iter: 08 | bpp: 0.9956 | SSIM: 0.8740 | MS-SSIM1: 0.9848 | MS-SSIM2: 0.9768 |PSNR: 32.0014
iter: 09 | bpp: 1.1207 | SSIM: 0.8906 | MS-SSIM1: 0.9866 | MS-SSIM2: 0.9805 |PSNR: 32.3129
iter: 10 | bpp: 1.2457 | SSIM: 0.9035 | MS-SSIM1: 0.9881 | MS-SSIM2: 0.9832 |PSNR: 32.6073
iter: 11 | bpp: 1.3708 | SSIM: 0.9149 | MS-SSIM1: 0.989

In [7]:
# Average every metric over all images, one output row per iteration.
results = []
for iteration, datas in results_dict.items():
    if not datas:
        # No image produced this iteration; skip it rather than divide by zero.
        continue
    # Transpose the rows into columns and drop column 0 (the per-row iteration label).
    metric_columns = list(zip(*datas))[1:]
    data = [iteration]
    data.extend(sum(column) / len(column) for column in metric_columns)
    results.append(data)

# Build the summary table.
df = pd.DataFrame(results, columns=['iter', 'bpp', 'SSIM', 'MS-SSIM1', 'MS-SSIM2', 'PSNR'])

# Save as an Excel file next to the code folder (e.g. data/kodim/results_rnn.xlsx).
# The path is derived from rnn_folder instead of a machine-specific absolute path.
excel_path = os.path.join(os.path.dirname(os.path.normpath(rnn_folder)), 'results_rnn.xlsx')
df.to_excel(excel_path, index=False)
    
    # print(f'{sampling} 샘플링 팩터에 대한 결과가 {excel_path}에 저장되었습니다.')
# JPEG 폴더의 모든 파일에 대해 MS-SSIM 계산


# for file_name in os.listdir(jpeg_folder):
#     if file_name.lower().endswith('.jpeg'):
#         # JPEG 파일 경로
#         jpeg_path = os.path.join(jpeg_folder, file_name)
        
#         # 원본 PNG 파일 경로
#         original_file_name = file_name.split('_')[0] + '.png'
#         original_path = os.path.join(png_folder, original_file_name)
        
#         # 이미지 열기
#         with Image.open(jpeg_path) as img_jpeg, Image.open(original_path) as img_original:
#             # 이미지를 Tensor로 변환
#             img_jpeg = torch.from_numpy(np.array(img_jpeg)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
#             img_original = torch.from_numpy(np.array(img_original)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            
#             # MS-SSIM 계산
#             score = ms_ssim(img_jpeg, img_original, data_range=1.0).item()
            
#             # 파일 크기 계산 (bytes -> bits)
#             jpeg_size = os.path.getsize(jpeg_path) * 8
#             bpp = jpeg_size / (img_jpeg.shape[2] * img_jpeg.shape[3])
            
#             # 결과 저장
#             results.append((bpp, score, file_name))

# # 결과를 bpp 기준으로 정렬
# results.sort()

# # bpp와 MS-SSIM 스코어 추출
# bpp_values = [r[0] for r in results]
# ms_ssim_scores = [r[1] for r in results]

# # 결과 출력
# for bpp, score, file_name in results:
#     print(f'파일: {file_name}, bpp: {bpp:.4f}, MS-SSIM: {score:.4f}')

# # 그래프 그리기
# plt.figure(figsize=(10, 6))
# plt.plot(bpp_values, ms_ssim_scores, marker='o')
# plt.xlabel('Bits Per Pixel (bpp)')
# plt.ylabel('MS-SSIM Score')
# plt.title('bpp vs MS-SSIM Score')
# plt.grid(True)
# plt.show()