1. Install CompressAI, a library that provides neural-network-based image compression models.

In [7]:
%pip install compressai==1.2.8

Collecting compressai==1.2.8
  Downloading compressai-1.2.8.tar.gz (183 kB)
     ------------------------------------- 183.8/183.8 kB 10.8 MB/s eta 0:00:00
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Collecting einops
  Using cached einops-0.8.1-py3-none-any.whl (64 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.7-cp311-cp311-win_amd64.whl (8.1 MB)
Collecting numpy<2.0,>=1.21.0
  Downloading numpy-1.26.4-cp311-cp311-win_amd64.whl (15.8 MB)
     --------------------------------------- 15.8/15.8 MB 50.3 MB/s eta 0:00:00
Collecting pandas
  Using cached pandas-2.3.3-cp311-cp311-win_amd64.whl (11.3 MB)
Collecting pybind11>=2.6.0
  Using cached pybind11-3.0.1-py3-none-any.whl (293 kB)
Collecting pyto

ERROR: Could not install packages due to an OSError: [WinError 2] The system cannot find the file specified: 'c:\\Python311\\Scripts\\wheel.exe' -> 'c:\\Python311\\Scripts\\wheel.exe.deleteme'


[notice] A new release of pip available: 22.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip
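
The log above ends with an OSError, so the install may not have completed. A minimal sketch for checking whether a package is importable before running the later cells; `is_installed` is a hypothetical helper written for this note, not part of pip or CompressAI.

```python
import importlib.util

def is_installed(package_name):
    """Return True if a top-level package can be found on the import path."""
    # find_spec returns None when no such top-level module exists
    return importlib.util.find_spec(package_name) is not None

if is_installed("compressai"):
    print("compressai is importable")
else:
    print("compressai is missing; re-run the install cell")
```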


2. Load a codec from the CompressAI library (downloading the weights may take some time)

In [None]:
from compressai.zoo import image_models

QP = 1 # [1, ..., 8] quality point discussed in the meeting; higher values give better image quality and a higher bitrate

codec_config = image_models['mbt2018-mean']    # Select mbt2018-mean from the codec models CompressAI provides

codec = codec_config(quality=QP, metric="mse", pretrained=True, progress=True) # Load the model trained for the chosen QP
codec = codec.eval()    # Switch to evaluation mode
codec.update()          # The codec must be initialized before use

3. Define utility functions for loading the example image

In [None]:
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def plot_img(torch_img):
    if len(torch_img.shape) == 4:
        torch_img = make_grid(torch_img)
    plt.imshow(torch_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

def PSNR(input1, input2):
    mse = torch.mean((input1 - input2) ** 2)
    psnr = 20 * torch.log10(1 / torch.sqrt(mse))
    return psnr.item()

def save_torch_image(img, save_path):
    img = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    Image.fromarray(img).save(save_path)

def load_torch_image(path):
    input_image = Image.open(path).convert('RGB')
    input_image = np.asarray(input_image).astype('float64').transpose(2, 0, 1)
    input_image = torch.from_numpy(input_image).type(torch.FloatTensor)
    input_image = input_image.unsqueeze(0)/255
    return input_image
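
The `PSNR` function above assumes pixel values in [0, 1], so the peak signal is 1 and the formula reduces to 20·log10(1/√MSE), which equals 10·log10(1/MSE). A minimal pure-Python sketch of the same arithmetic on a handful of made-up pixel values; `psnr_from_pixels` and its `peak` parameter are names chosen for this illustration.

```python
import math

def psnr_from_pixels(a, b, peak=1.0):
    """PSNR in dB between two equal-length sequences of pixel values in [0, peak]."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    if mse == 0:
        return math.inf  # identical inputs: PSNR is unbounded
    return 10 * math.log10(peak ** 2 / mse)

# Illustrative values only: every pixel is off by 0.1, so MSE is about 0.01
original = [0.0, 0.5, 1.0, 0.25]
decoded = [0.1, 0.4, 0.9, 0.35]
print(round(psnr_from_pixels(original, decoded), 2))  # about 20 dB
```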

4. Check the example image

In [None]:
img = load_torch_image('./kodim01.png')
print(img.shape)
plot_img(img)

5. Compress and reconstruct the image

In [None]:
# CompressAI's image compression models work correctly when the height and width are multiples of 64.
# The image therefore has to be padded before compression and cropped after decoding.

def get_padded_img(img, p=64):
    height, width = img.shape[-2:]
    new_h = (height + p - 1) // p * p
    new_w = (width + p - 1) // p * p
    padding_l = 0
    padding_r = new_w - width - padding_l
    padding_t = 0
    padding_b = new_h - height - padding_t
    pad_info = (padding_l, padding_r, padding_t, padding_b)
    x_padded = torch.nn.functional.pad(
        img,
        pad_info,
        mode="constant", value=0,
    )
    return x_padded, pad_info

def get_cropped_img(padded_img, pad_info):
    reverse_pad_info = tuple(-p for p in pad_info)
    cropped_img = torch.nn.functional.pad(padded_img, reverse_pad_info)
    return cropped_img

with torch.no_grad():
    # Padding (+ get padding information)
    img_padded, pad_info = get_padded_img(img)

    # Encoding (compression, img -> strings)
    compressed = codec.compress(img_padded)
    strings = compressed['strings']
    shape = compressed['shape'] # side information required for decoding

    # Decoding (reconstruction, strings -> img)
    decompressed = codec.decompress(strings=strings, shape=shape)
    decoded_img = decompressed['x_hat']

    # Cropping (using padding information)
    cropped_decoded_img = get_cropped_img(decoded_img, pad_info)

# Visualization (with Original image)
plot_img(torch.cat([img, cropped_decoded_img], dim=0))
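
The padding arithmetic in `get_padded_img` rounds each dimension up to the next multiple of `p` and pads only on the right and bottom. A small sketch of just that arithmetic, without tensors; `padded_size` is a hypothetical helper, and the 375x500 size is an arbitrary example rather than an image from this notebook.

```python
def padded_size(height, width, p=64):
    """Round height and width up to the next multiple of p.

    Returns (new_h, new_w, pad_info), where pad_info uses the
    (left, right, top, bottom) order that torch.nn.functional.pad
    expects for the last two dimensions.
    """
    new_h = (height + p - 1) // p * p
    new_w = (width + p - 1) // p * p
    return new_h, new_w, (0, new_w - width, 0, new_h - height)

print(padded_size(512, 768))  # already multiples of 64, so no padding is added
print(padded_size(375, 500))  # padded to 384 x 512
```

Because `get_cropped_img` negates the same `pad_info`, the crop removes exactly the rows and columns that the padding added.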

6. Summary of compression results

In [None]:
psnr = PSNR(img, cropped_decoded_img)
num_pixels = img.shape[-2] * img.shape[-1]
bpp = sum(len(s[0]) * 8 for s in strings) / num_pixels   # bits per pixel: amount of data normalized by the image resolution
print(f"PSNR: {psnr:.3f} | Bpp: {bpp:.3f}")
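
The bpp expression above sums the byte length of each bitstream, converts it to bits, and divides by the pixel count of the original (unpadded) image. A sketch of the same calculation on dummy byte strings; `bits_per_pixel` is a hypothetical helper, the byte lengths are invented for illustration, and the nested-list layout assumes a batch size of 1, as `s[0]` does in the cell above.

```python
def bits_per_pixel(strings, height, width):
    """Total bitstream size in bits divided by the number of pixels.

    `strings` is a list of per-bitstream lists; only the first batch
    element (index 0) of each is counted, so batch size 1 is assumed.
    """
    total_bits = sum(len(s[0]) * 8 for s in strings)
    return total_bits / (height * width)

# Dummy bitstreams of 1000 and 200 bytes (not real codec output)
dummy_strings = [[b"\x00" * 1000], [b"\x00" * 200]]
print(bits_per_pixel(dummy_strings, 512, 768))  # 9600 bits / 393216 pixels
```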

After studying the above, how do you generate a dataset of compressed images?

In [None]:
################# Utility function definitions ###########
from pathlib import Path
import os
import torch
import numpy as np
from torchvision.utils import make_grid
from PIL import Image
import matplotlib.pyplot as plt
import tqdm
from compressai.zoo import image_models

def plot_img(torch_img):
    if len(torch_img.shape) == 4:
        torch_img = make_grid(torch_img)
    plt.imshow(torch_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

def PSNR(input1, input2):
    mse = torch.mean((input1 - input2) ** 2)
    psnr = 20 * torch.log10(1 / torch.sqrt(mse))
    return psnr.item()

def save_torch_image(img, save_path):
    img = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    Image.fromarray(img).save(save_path)

def load_torch_image(path):
    input_image = Image.open(path).convert('RGB')
    input_image = np.asarray(input_image).astype('float64').transpose(2, 0, 1)
    input_image = torch.from_numpy(input_image).type(torch.FloatTensor)
    input_image = input_image.unsqueeze(0)/255
    return input_image

def get_padded_img(img, p=64):
    height, width = img.shape[-2:]
    new_h = (height + p - 1) // p * p
    new_w = (width + p - 1) // p * p
    padding_l = 0
    padding_r = new_w - width - padding_l
    padding_t = 0
    padding_b = new_h - height - padding_t
    pad_info = (padding_l, padding_r, padding_t, padding_b)
    x_padded = torch.nn.functional.pad(
        img,
        pad_info,
        mode="constant", value=0,
    )
    return x_padded, pad_info

def get_cropped_img(padded_img, pad_info):
    reverse_pad_info = tuple(-p for p in pad_info)
    cropped_img = torch.nn.functional.pad(padded_img, reverse_pad_info)
    return cropped_img
############################################


################# Edit as needed ################
data_path = Path('./kodak24')               # path for your dataset
save_path = Path('./kodak24_compressed')    # path for compressed dataset
############################################

device = 'cuda' if torch.cuda.is_available() else 'cpu'
QP_list = [1, 2] # QPs to test
os.makedirs(save_path, exist_ok=True)

for QP in QP_list:  # For each QP to test
    os.makedirs(save_path / f'qp{QP}', exist_ok=True)

    codec_config = image_models['mbt2018-mean']
    codec = codec_config(quality=QP, metric="mse", pretrained=True, progress=True)
    codec = codec.to(device)
    codec = codec.eval()
    codec.update()

    # Collect the image files in the dataset directory, sorted by name
    file_paths = sorted([
        p for p in data_path.iterdir()
        if p.suffix.lower() in [".png", ".jpg", ".jpeg", ".bmp", ".tiff"]
    ])

    psnrs = []
    bpps = []
    for img_path in file_paths:
        img = load_torch_image(img_path).to(device)

        with torch.no_grad():
            # Padding (+ get padding information)
            img_padded, pad_info = get_padded_img(img)

            # Encoding (compression, img -> strings)
            compressed = codec.compress(img_padded)
            strings = compressed['strings']
            shape = compressed['shape'] # side information required for decoding

            # Decoding (reconstruction, strings -> img)
            decompressed = codec.decompress(strings=strings, shape=shape)
            decoded_img = decompressed['x_hat']

            # Cropping (using padding information)
            cropped_decoded_img = get_cropped_img(decoded_img, pad_info)
            num_pixels = img.shape[-2] * img.shape[-1]
            bpp = sum(len(s[0]) * 8 for s in strings) / num_pixels   # bits per pixel: amount of data normalized by the image resolution
            bpps.append(bpp)
            psnrs.append(PSNR(img, cropped_decoded_img))
            # Save the reconstructed (compressed) image, not the original input
            save_torch_image(cropped_decoded_img, save_path / f'qp{QP}' / img_path.name)

    print(f"QP: {QP} completed, average PSNR: {sum(psnrs) / len(file_paths):.3f} | Bpp: {sum(bpps) / len(file_paths):.4f}")
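
The loop above prints only per-QP averages. If per-image numbers are also wanted alongside the compressed dataset, one option is to write them to a CSV file. This is a sketch under assumptions, not part of the original notebook: `write_results_csv`, the column names, and the row values are all hypothetical, and the numbers shown are dummies rather than measured results.

```python
import csv
import io

FIELDS = ["qp", "filename", "psnr", "bpp"]

def write_results_csv(rows, fileobj):
    """Write one row per (QP, image) to an open text file object."""
    writer = csv.DictWriter(fileobj, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)

# Dummy rows for illustration only (not measured results)
rows = [
    {"qp": 1, "filename": "kodim01.png", "psnr": 0.0, "bpp": 0.0},
    {"qp": 2, "filename": "kodim01.png", "psnr": 0.0, "bpp": 0.0},
]

buffer = io.StringIO()  # replace with open(path, "w", newline="") to write a real file
write_results_csv(rows, buffer)
print(buffer.getvalue())
```

In the loop, a row could be appended right after `psnrs.append(...)`, using `img_path.name` as the filename.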
