<a href="https://colab.research.google.com/github/santoshpremi/Perceptual_Learned_Image_Compression/blob/main/Colab_PLIC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Perceptual Learned Image Compression
Introduction:
With most internet traffic consisting of image and (especially) video data, image compression algorithms are essential for handling this volume of data. Neural Image Compression (NIC) methods have been shown to achieve significantly better compression rates than handcrafted algorithms [1]. However, both approaches can suffer from artifacts or overly smoothed images.

Perceptual Image Compression [2, 3] tries to generate reconstructions that are perceived to be of high quality, even if the actual pixel values differ from the original image.

Goals:
- Design and implement a NIC method based on a SOTA method
- Explore different approaches for optimizing for perceptual quality
- Evaluate on benchmark datasets in terms of distortion, perceptual quality and speed

Desirable Experience:
- Knowledge of Computer Vision, Deep Learning and Compression
- Experience in the PyTorch or TensorFlow framework

References
1. He, Dailan, Ziming Yang, Hongjiu Yu, Tongda Xu, Jixiang Luo, Yuan Chen, Chenjian Gao, Xinjie Shi, Hongwei Qin, and Yan Wang. “PO-ELIC: Perception-Oriented Efficient Learned Image Coding.” In 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 1763–68. New Orleans, LA, USA: IEEE, 2022. https://doi.org/10.1109/CVPRW56347.2022.00187.
2. Mentzer, Fabian, George Toderici, Michael Tschannen, and Eirikur Agustsson. “High-Fidelity Generative Image Compression.” arXiv, October 23, 2020. https://doi.org/10.48550/arXiv.2006.09965.
3. Ning, Peirong, Wei Jiang, and Ronggang Wang. “HFLIC: Human Friendly Perceptual Learned Image Compression with Reinforced Transform.” arXiv, May 18, 2023. https://doi.org/10.48550/arXiv.2305.07519.





## HFLIC: Human Friendly Perceptual Learned Image Compression with Reinforced Transform

This notebook implements **HFLIC (Human Friendly Perceptual Learned Image Compression)**, following the paper "HFLIC: Human Friendly Perceptual Learned Image Compression with Reinforced Transform" [3]. The model extends ELIC with reinforced transform blocks to improve perceptual quality.

In [None]:
import math
import io
import torch
import json
import time
import os
import sys
from pathlib import Path
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
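import locale
# Workaround commonly used on Colab: report UTF-8 as the preferred encoding
# so that `!` shell commands such as pip do not fail with a locale error
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding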
!pip install compressai
!pip install lpips

# Add project path for imports
sys.path.append('/content')



In [None]:
import os
import inspect
import compressai

# Get the path to the compressai package
compressai_path = os.path.dirname(inspect.getfile(compressai))
print(f"CompressAI package path: {compressai_path}")

# List contents of the compressai directory
print("Contents of compressai package:")
for item in os.listdir(compressai_path):
    print(f"- {item}")

# Check specifically for the 'metrics' submodule
metrics_path = os.path.join(compressai_path, 'metrics')
print(f"\nChecking for compressai/metrics at: {metrics_path}")
if os.path.exists(metrics_path) and os.path.isdir(metrics_path):
    print("compressai/metrics directory exists.")
    print("Contents of compressai/metrics:")
    for item in os.listdir(metrics_path):
        print(f"  - {item}")
    # Try a direct import again to see if it works after explicit path check
    try:
        from compressai import metrics
        print("Successfully imported compressai.metrics.")
    except Exception as e:
        print(f"Failed to import compressai.metrics even after path check: {e}")
else:
    print("compressai/metrics directory does NOT exist.")


CompressAI package path: /usr/local/lib/python3.12/dist-packages/compressai
Contents of compressai package:
- typing
- __init__.py
- datasets
- layers
- ans.cpython-312-x86_64-linux-gnu.so
- sadl_codec
- registry
- latent_codecs
- _CXX.cpython-312-x86_64-linux-gnu.so
- entropy_models
- losses
- transforms
- zoo
- optimizers
- __pycache__
- version.py
- utils
- models
- ops

Checking for compressai/metrics at: /usr/local/lib/python3.12/dist-packages/compressai/metrics
compressai/metrics directory does NOT exist.


## **Train HFLIC: Vimeo90K dataset and save weights**

Training HFLIC model on Vimeo90K dataset with perceptual loss:
- **Loss = 0.1·MSE + 0.9·LPIPS + 10.0·λ·BPP**
- Model: HFLIC with N=192, M=320
- Training on Vimeo90K test split (for fine-tuning)
- Saving checkpoints for evaluation
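
The weighting in the loss above can be illustrated with plain arithmetic. The following is a minimal sketch, not the `PerceptualLoss` class used later in the notebook (whose source is not shown here). The helper names `bits_per_pixel` and `perceptual_rd_loss` and all input numbers are assumptions made for illustration; only the weights 0.1, 0.9 and 10.0 and the default λ = 0.08 come from this notebook.

```python
import math


def bits_per_pixel(likelihoods, num_pixels):
    """Rate term: sum of -log2(likelihood) over latent symbols, divided by pixel count."""
    return sum(-math.log2(p) for p in likelihoods) / num_pixels


def perceptual_rd_loss(mse, lpips_value, bpp, lmbda,
                       w_mse=0.1, w_lpips=0.9, w_rate=10.0):
    """Loss = 0.1*MSE + 0.9*LPIPS + 10.0*lambda*BPP (weights as listed above)."""
    return w_mse * mse + w_lpips * lpips_value + w_rate * lmbda * bpp


# Toy inputs, made up for illustration: four latent symbols covering 16 pixels.
likelihoods = [0.5, 0.25, 0.25, 0.5]               # costs 1 + 2 + 2 + 1 = 6 bits
bpp = bits_per_pixel(likelihoods, num_pixels=16)   # 6 / 16 bits per pixel

loss = perceptual_rd_loss(mse=0.002, lpips_value=0.15, bpp=bpp, lmbda=0.08)
print(f"bpp={bpp:.3f}  loss={loss:.4f}")
```

With these toy numbers the rate term (10.0 · 0.08 · 0.375 = 0.3) dominates the distortion terms, which shows how strongly λ steers the trade-off between bitrate and perceptual quality.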


In [None]:
from google.colab import drive
drive.mount('/content/drive')


!ls "/content"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Mounted at /content/drive
drive  sample_data


In [None]:
# Write HFLIC model and layer files to disk
import os

# Create necessary directories
os.makedirs('/content/modules', exist_ok=True)
os.makedirs('/content/models', exist_ok=True)

# Write layers.py
layers_code = '''"""
Layer modules for HFLIC/ELIC architecture
Based on ELIC: Efficient Learned Image Compression with Unevenly Grouped Space-Channel Contextual Adaptive Coding
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class GDN(nn.Module):
    """Generalized Divisive Normalization"""
    def __init__(self, in_channels, inverse=False, beta_min=1e-6, gamma_init=0.1):
        super(GDN, self).__init__()
        self.inverse = inverse
        self.beta_min = beta_min
        beta = torch.ones(in_channels)
        gamma = torch.eye(in_channels) * gamma_init
        self.register_buffer("beta", beta)
        self.register_parameter("gamma", nn.Parameter(gamma))

    def forward(self, x):
        gamma = self.gamma + self.beta_min
        beta = self.beta + self.beta_min

        # Normalization pool per pixel: sqrt(beta_i + sum_j gamma_ij * x_j^2)
        norm = torch.matmul(gamma, (x ** 2).permute(0, 2, 3, 1).unsqueeze(-1))
        norm = norm.squeeze(-1).permute(0, 3, 1, 2)
        norm = torch.sqrt(norm + beta.view(1, -1, 1, 1))

        # GDN divides by the pool; the inverse (IGDN) multiplies by the same pool
        return x * norm if self.inverse else x / norm


class IGDN(nn.Module):
    """Inverse Generalized Divisive Normalization"""
    def __init__(self, in_channels, beta_min=1e-6, gamma_init=0.1):
        super(IGDN, self).__init__()
        self.gdn = GDN(in_channels, inverse=True, beta_min=beta_min, gamma_init=gamma_init)

    def forward(self, x):
        return self.gdn(x)


class ResidualBottleneck(nn.Module):
    """Residual bottleneck block for ELIC/HFLIC"""
    def __init__(self, N, M, act=nn.ReLU(inplace=True)):
        super(ResidualBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(N, M, 1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(M, M, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(M, N, 1, stride=1, padding=0)
        self.act = act

    def forward(self, x):
        identity = x
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        out = self.conv3(out)
        return out + identity


class AttentionBlock(nn.Module):
    """Attention mechanism for reinforced transform in HFLIC"""
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.channels = channels
        self.conv = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        attention = torch.sigmoid(self.conv(x))
        return x * attention


class ReinforcedTransform(nn.Module):
    """Reinforced Transform module for HFLIC - no downsampling, just attention + transformation"""
    def __init__(self, N, M):
        super(ReinforcedTransform, self).__init__()
        self.attention = AttentionBlock(N)
        self.conv = nn.Sequential(
            nn.Conv2d(N, M, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(M, N, 3, stride=1, padding=1)
        )

    def forward(self, x):
        x_attn = self.attention(x)
        x_transformed = self.conv(x_attn)
        return x + x_transformed  # Residual connection


class ContextualAttention(nn.Module):
    """Contextual attention module for adaptive coding"""
    def __init__(self, channels):
        super(ContextualAttention, self).__init__()
        self.conv_query = nn.Conv2d(channels, channels // 4, 1)
        self.conv_key = nn.Conv2d(channels, channels // 4, 1)
        self.conv_value = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.size()
        query = self.conv_query(x).view(B, -1, H * W)
        key = self.conv_key(x).view(B, -1, H * W)
        value = self.conv_value(x).view(B, -1, H * W)

        attention = torch.softmax(torch.bmm(query.transpose(1, 2), key), dim=-1)
        out = torch.bmm(value, attention.transpose(1, 2))
        return out.view(B, C, H, W)
'''

with open('/content/modules/layers.py', 'w') as f:
    f.write(layers_code)

# Write elic.py
elic_code = '''"""
ELIC: Efficient Learned Image Compression with Unevenly Grouped Space-Channel Contextual Adaptive Coding
Base implementation for HFLIC
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import GDN
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from modules.layers import ResidualBottleneck, ContextualAttention


class ELICAnalysis(nn.Module):
    """Analysis transform for ELIC"""
    def __init__(self, N=192, M=320):
        super(ELICAnalysis, self).__init__()
        self.N = N
        self.M = M

        # Initial convolution
        self.conv1 = nn.Conv2d(3, N, 5, stride=2, padding=2)
        self.gdn1 = GDN(N)

        # Residual blocks with attention
        self.res1 = ResidualBottleneck(N, N * 2)
        self.res2 = ResidualBottleneck(N, N * 2)
        self.attn1 = ContextualAttention(N)

        # Downsampling
        self.conv2 = nn.Conv2d(N, N, 5, stride=2, padding=2)
        self.gdn2 = GDN(N)

        self.res3 = ResidualBottleneck(N, N * 2)
        self.res4 = ResidualBottleneck(N, N * 2)
        self.attn2 = ContextualAttention(N)

        # Final transform
        self.conv3 = nn.Conv2d(N, M, 3, stride=1, padding=1)

    def forward(self, x):
        x = self.gdn1(self.conv1(x))
        x = self.res1(x)
        x = self.res2(x)
        x = self.attn1(x)
        x = self.gdn2(self.conv2(x))
        x = self.res3(x)
        x = self.res4(x)
        x = self.attn2(x)
        x = self.conv3(x)
        return x


class ELICSynthesis(nn.Module):
    """Synthesis transform for ELIC"""
    def __init__(self, N=192, M=320):
        super(ELICSynthesis, self).__init__()
        self.N = N
        self.M = M

        # Initial transform
        self.conv1 = nn.Conv2d(M, N, 3, stride=1, padding=1)

        # Residual blocks with attention
        self.attn1 = ContextualAttention(N)
        self.res1 = ResidualBottleneck(N, N * 2)
        self.res2 = ResidualBottleneck(N, N * 2)

        # Upsampling
        self.conv2 = nn.ConvTranspose2d(N, N, 5, stride=2, padding=2, output_padding=1)
        self.igdn1 = GDN(N, inverse=True)

        self.attn2 = ContextualAttention(N)
        self.res3 = ResidualBottleneck(N, N * 2)
        self.res4 = ResidualBottleneck(N, N * 2)

        # Final transform
        self.conv3 = nn.ConvTranspose2d(N, 3, 5, stride=2, padding=2, output_padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.attn1(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.igdn1(self.conv2(x))
        x = self.attn2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.conv3(x)
        return x


class ELIC(nn.Module):
    """
    ELIC: Efficient Learned Image Compression
    Base model for HFLIC
    """
    def __init__(self, N=192, M=320):
        super(ELIC, self).__init__()
        self.N = N
        self.M = M

        self.analysis = ELICAnalysis(N, M)
        self.synthesis = ELICSynthesis(N, M)

        # Entropy models
        self.hyperprior_entropy_bottleneck = EntropyBottleneck(M) # Entropy bottleneck for hyperprior (z)
        self.gaussian_conditional = GaussianConditional(None)

        # Hyper-synthesis network to predict scales and means for y from z_hat
        self.h_s = nn.Sequential(
            nn.Conv2d(M, M * 2, 3, stride=1, padding=1), # Outputs 2*M channels for scales and means
            nn.ReLU(inplace=True),
            nn.Conv2d(M * 2, M * 2, 3, stride=1, padding=1),
        )

    def forward(self, x):
        # Analysis transform
        y = self.analysis(x)

        # Context model for hyperprior (z from abs(y) as in original code)
        z = torch.abs(y)
        z_hat, z_likelihoods = self.hyperprior_entropy_bottleneck(z)

        # Predict scales and means for y from z_hat using hyper-synthesis
        gaussian_params = self.h_s(z_hat)
        scales_hat, means_hat = gaussian_params.chunk(2, 1) # Split into M channels for scales, M for means

        # Gaussian conditional for y
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)

        # Synthesis transform
        x_hat = self.synthesis(y_hat)

        return {
            "x_hat": x_hat.clamp(0, 1),
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }
'''

with open('/content/models/elic.py', 'w') as f:
    f.write(elic_code)

# Write hflic.py
hflic_code = '''"""
HFLIC: Human Friendly Perceptual Learned Image Compression with Reinforced Transform
Extends ELIC with reinforced transform for better perceptual quality
"""

import torch
import torch.nn as nn
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from models.elic import ELIC
from modules.layers import ReinforcedTransform, AttentionBlock


class HFLICAnalysis(nn.Module):
    """Analysis transform for HFLIC with reinforced transform"""
    def __init__(self, N=192, M=320):
        super(HFLICAnalysis, self).__init__()
        from compressai.layers import GDN
        from modules.layers import ResidualBottleneck, ContextualAttention

        self.N = N
        self.M = M

        # Initial convolution
        self.conv1 = nn.Conv2d(3, N, 5, stride=2, padding=2)
        self.gdn1 = GDN(N)

        # Reinforced transform blocks
        self.reinforced1 = ReinforcedTransform(N, N)
        self.reinforced2 = ReinforcedTransform(N, N)

        # Residual blocks with attention
        self.res1 = ResidualBottleneck(N, N * 2)
        self.res2 = ResidualBottleneck(N, N * 2)
        self.attn1 = ContextualAttention(N)

        # Downsampling with reinforced transform
        self.conv2 = nn.Conv2d(N, N, 5, stride=2, padding=2)
        self.gdn2 = GDN(N)

        self.reinforced3 = ReinforcedTransform(N, N)
        self.res3 = ResidualBottleneck(N, N * 2)
        self.res4 = ResidualBottleneck(N, N * 2)
        self.attn2 = ContextualAttention(N)

        # Final transform
        self.conv3 = nn.Conv2d(N, M, 3, stride=1, padding=1)

    def forward(self, x):
        x = self.gdn1(self.conv1(x))
        x = self.reinforced1(x)
        x = self.reinforced2(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.attn1(x)
        x = self.gdn2(self.conv2(x))
        x = self.reinforced3(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.attn2(x)
        x = self.conv3(x)
        return x


class HFLICSynthesis(nn.Module):
    """Synthesis transform for HFLIC with reinforced transform"""
    def __init__(self, N=192, M=320):
        super(HFLICSynthesis, self).__init__()
        from compressai.layers import GDN
        from modules.layers import ResidualBottleneck, ContextualAttention

        self.N = N
        self.M = M

        # Initial transform
        self.conv1 = nn.Conv2d(M, N, 3, stride=1, padding=1)

        # Residual blocks with attention
        self.attn1 = ContextualAttention(N)
        self.res1 = ResidualBottleneck(N, N * 2)
        self.res2 = ResidualBottleneck(N, N * 2)

        # Upsampling with reinforced transform
        self.conv2 = nn.ConvTranspose2d(N, N, 5, stride=2, padding=2, output_padding=1)
        self.igdn1 = GDN(N, inverse=True)

        self.reinforced1 = ReinforcedTransform(N, N)
        self.attn2 = ContextualAttention(N)
        self.res3 = ResidualBottleneck(N, N * 2)
        self.res4 = ResidualBottleneck(N, N * 2)

        # Final transform
        self.conv3 = nn.ConvTranspose2d(N, 3, 5, stride=2, padding=2, output_padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.attn1(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.igdn1(self.conv2(x))
        x = self.reinforced1(x)
        x = self.attn2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.conv3(x)
        return x


class HFLIC(nn.Module):
    """
    HFLIC: Human Friendly Perceptual Learned Image Compression with Reinforced Transform

    Extends ELIC by incorporating reinforced transform blocks that enhance
    perceptual quality through attention mechanisms.
    """
    def __init__(self, N=192, M=320):
        super(HFLIC, self).__init__()
        self.N = N
        self.M = M

        self.analysis = HFLICAnalysis(N, M)
        self.synthesis = HFLICSynthesis(N, M)

        # Entropy models
        from compressai.entropy_models import EntropyBottleneck, GaussianConditional
        self.hyperprior_entropy_bottleneck = EntropyBottleneck(M) # Entropy bottleneck for hyperprior (z)
        self.gaussian_conditional = GaussianConditional(None)

        # Hyper-synthesis network to predict scales and means for y from z_hat
        self.h_s = nn.Sequential(
            nn.Conv2d(M, M * 2, 3, stride=1, padding=1), # Outputs 2*M channels for scales and means
            nn.ReLU(inplace=True),
            nn.Conv2d(M * 2, M * 2, 3, stride=1, padding=1),
        )

    def forward(self, x):
        # Analysis transform
        y = self.analysis(x)

        # Context model for hyperprior (z from abs(y) as in original code)
        z = torch.abs(y)
        z_hat, z_likelihoods = self.hyperprior_entropy_bottleneck(z)

        # Predict scales and means for y from z_hat using hyper-synthesis
        gaussian_params = self.h_s(z_hat)
        scales_hat, means_hat = gaussian_params.chunk(2, 1) # Split into M channels for scales, M for means

        # Gaussian conditional for y
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)

        # Synthesis transform
        x_hat = self.synthesis(y_hat)

        return {
            "x_hat": x_hat.clamp(0, 1),
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def compress(self, x):
        """Return the quantized latent y_hat (no bitstream is produced here)"""
        y = self.analysis(x)
        # The hyperprior input z is derived directly from y, as in forward()
        z = torch.abs(y)
        # CompressAI's quantize() takes the mode as its second argument;
        # "dequantize" rounds the input and returns a float tensor
        z_hat = self.hyperprior_entropy_bottleneck.quantize(z, "dequantize")

        # Predict scales and means for y from z_hat
        gaussian_params = self.h_s(z_hat)
        scales_hat, means_hat = gaussian_params.chunk(2, 1)

        # Round y around the predicted means (scales are only needed for entropy coding)
        y_hat = self.gaussian_conditional.quantize(y, "dequantize", means=means_hat)

        return y_hat

    def decompress(self, y_string, z_string, shape):
        """Decompress image"""
        # Decompression logic would go here
        # This part requires a full implementation of decoding from bitstreams
        # and is more complex than a simple forward pass.
        # For now, it's a placeholder.
        raise NotImplementedError("Decompression is not fully implemented yet.")
'''

with open('/content/models/hflic.py', 'w') as f:
    f.write(hflic_code)

print("HFLIC model files written successfully!")


HFLIC model files written successfully!


In [None]:
import argparse
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from models.hflic import HFLIC
from datasets.vimeo90k import Vimeo90KSingleFrameDataset
from loss.perceptual_loss import PerceptualLoss
from utils.utils import AverageMeter, compute_psnr, create_train_transform


def parse_args(argv):
    parser = argparse.ArgumentParser(description="HFLIC training loop inside notebook")
    parser.add_argument('--dataset', type=str, required=True, help='Path to Vimeo90K dataset')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--aux-learning-rate', type=float, default=1e-3)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--patch-size', type=int, nargs=2, default=(256, 256))
    parser.add_argument('--lambda', dest='lmbda', type=float, default=0.08)
    parser.add_argument('--clip-max-norm', type=float, default=0.1)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--print-freq', type=int, default=100)
    return parser.parse_args(argv)


def configure_optimizers(net, args):
    params, aux_params = [], []
    for name, param in net.named_parameters():
        if 'entropy_bottleneck' in name or 'gaussian_conditional' in name:
            aux_params.append(param)
        else:
            params.append(param)

    main_optim = optim.Adam(params, lr=args.learning_rate)
    aux_optim = optim.Adam(aux_params, lr=args.aux_learning_rate) if aux_params else None

    scheduler = torch.optim.lr_scheduler.StepLR(
        main_optim,
        step_size=max(args.epochs // 4, 1),
        gamma=0.5,
    )
    scheduler_aux = (
        torch.optim.lr_scheduler.StepLR(aux_optim, step_size=max(args.epochs // 4, 1), gamma=0.5)
        if aux_optim is not None
        else None
    )

    return main_optim, aux_optim, scheduler, scheduler_aux


def train_one_epoch(model, criterion, dataloader, optimizer, aux_optimizer, epoch, args):
    model.train()
    device = next(model.parameters()).device

    loss_meter = AverageMeter()
    bpp_meter = AverageMeter()
    psnr_meter = AverageMeter()
    lpips_meter = AverageMeter()
    mse_meter = AverageMeter()  # needed because the returned metrics include 'mse'

    for i, (images, _) in enumerate(dataloader):
        images = images.to(device)
        optimizer.zero_grad()
        if aux_optimizer is not None:
            aux_optimizer.zero_grad()

        output = model(images)
        loss_dict = criterion(output, images)
        total_loss = loss_dict['loss']
        total_loss.backward()

        if args.clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_max_norm)
        optimizer.step()
        if aux_optimizer is not None:
            aux_optimizer.step()

        batch_size = images.size(0)
        loss_meter.update(total_loss.item(), batch_size)
        bpp_meter.update(loss_dict['bpp_loss'].item(), batch_size)
        lpips_meter.update(loss_dict['lpips_loss'].item(), batch_size)
        psnr_meter.update(compute_psnr(images, output['x_hat']), batch_size)
        # Pixel-space MSE between input and reconstruction, computed directly
        mse_meter.update(torch.mean((images - output['x_hat'].detach()) ** 2).item(), batch_size)

        if i % args.print_freq == 0:
            print(
                f"Epoch {epoch} | Iter {i}/{len(dataloader)} | "
                f"Loss {loss_meter.avg:.4f} | BPP {bpp_meter.avg:.4f} | "
                f"LPIPS {lpips_meter.avg:.4f} | PSNR {psnr_meter.avg:.2f}"
            )

    return {
        'loss': loss_meter.avg,
        'bpp': bpp_meter.avg,
        'psnr': psnr_meter.avg,
        'lpips': lpips_meter.avg,
        'mse': mse_meter.avg,
    }


def main(argv):
    args = parse_args(argv)
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    crop = args.patch_size[0] if isinstance(args.patch_size, (list, tuple)) else args.patch_size
    transform = create_train_transform(crop_size=crop)
    dataset = Vimeo90KSingleFrameDataset(
        root_dir=args.dataset,
        split='train',
        transform=transform
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device == 'cuda'
    )

    model = HFLIC().to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = PerceptualLoss(lmbda=args.lmbda).to(device)
    optimizer, aux_optimizer, scheduler, scheduler_aux = configure_optimizers(model, args)

    for epoch in range(args.epochs):
        # train_one_epoch reads clip_max_norm and print_freq from args
        metrics = train_one_epoch(
            model,
            criterion,
            dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args,
        )
        print(
            "Epoch {}/{} summary: Loss {:.4f} | BPP {:.4f} | LPIPS {:.4f} | PSNR {:.2f} | MSE {:.6f}".format(
                epoch + 1,
                args.epochs,
                metrics['loss'],
                metrics['bpp'],
                metrics['lpips'],
                metrics['psnr'],
                metrics['mse'],
            )
        )

        scheduler.step()
        if scheduler_aux is not None:
            scheduler_aux.step()


if __name__ == '__main__':
    main([
        '--dataset', '/content/drive/My Drive/vimeo_test_clean',
        '--epochs', '100',
        '--cuda',
        '--batch-size', '4',
        '--learning-rate', '1e-4',
        '--aux-learning-rate', '1e-3',
        '--clip-max-norm', '0.1',
        '--print-freq', '50',
    ])



ModuleNotFoundError: No module named 'models'

## **Evaluating HFLIC: Kodak dataset with saved weights**

Evaluating trained HFLIC model on Kodak dataset (24 standard test images) with comprehensive metrics:
- **PSNR**: Peak Signal-to-Noise Ratio (dB)
- **LPIPS**: Learned Perceptual Image Patch Similarity (lower is better)
- **BPP**: Bits Per Pixel (compression efficiency)
- **MSE**: Mean Squared Error
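
As a quick reference for reading these numbers, the sketch below works through the PSNR formula and converts a bits-per-pixel rate into an approximate payload size. It is an illustration only: the helper names `psnr_from_mse` and `compressed_size_bytes` and the input values are assumptions, and the 256×256 resolution simply matches the resize applied in the evaluation cell.

```python
import math


def psnr_from_mse(mse, max_val=1.0):
    """PSNR in dB: 10 * log10(max_val^2 / MSE); infinite for a perfect reconstruction."""
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / mse)


def compressed_size_bytes(bpp, height, width):
    """Approximate payload size implied by a bits-per-pixel rate."""
    return bpp * height * width / 8.0


# Illustrative values, not measured results.
psnr_db = psnr_from_mse(0.001)               # about 30 dB for images scaled to [0, 1]
size = compressed_size_bytes(0.5, 256, 256)  # 4096 bytes for a 256x256 image at 0.5 bpp
print(f"PSNR: {psnr_db:.2f} dB, size: {size:.0f} bytes")
```

Because PSNR is a logarithm of MSE, the two metrics carry the same information for a single image; they differ only once values are averaged across images.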

---



In [None]:
# Import necessary libraries
import os
import torch
import lpips
import numpy as np
import math
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from compressai.zoo import image_models
from compressai.losses import RateDistortionLoss

# Define a custom dataset class for Kodak images
class KodakDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort for a deterministic image order; match extensions case-insensitively
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0  # Return 0 for compatibility with dataloader (label is not used)

# Utility class to compute running average of the losses
class AverageMeter:
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Function to compute PSNR
def compute_psnr(a, b, max_val: float = 1.0) -> float:
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        return float('inf')
    # PSNR = 10 * log10(max_val^2 / MSE); for max_val = 1 this equals -10 * log10(MSE)
    return 10 * math.log10(max_val ** 2 / mse)

# Function to compute BPP
def compute_bpp(out_net, num_pixels):
    """Compute bits per pixel from likelihoods"""
    bpp = 0.0
    for likelihoods in out_net["likelihoods"].values():
        bpp += torch.log(likelihoods + 1e-10).sum() / (-math.log(2) * num_pixels)
    return bpp.item()

# Function to evaluate the model on Kodak dataset with PSNR, LPIPS, and BPP
def evaluate_model(model, dataloader, device, model_name="Model"):
    model.eval()
    psnr_meter = AverageMeter()
    bpp_meter = AverageMeter()
    lpips_meter = AverageMeter()

    lpips_loss = lpips.LPIPS(net='alex').to(device)

    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            out_net = model(images)

            # Compute PSNR
            psnr = compute_psnr(images, out_net['x_hat'])
            psnr_meter.update(psnr, images.size(0))

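            # Compute BPP (corrected: include batch size)
            num_pixels = images.size(0) * images.size(2) * images.size(3)
            bpp = compute_bpp(out_net, num_pixels)
            # Weight by batch size, like the other meters, so a smaller final batch does not skew the average
            bpp_meter.update(bpp, images.size(0))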

            # Compute LPIPS
            # LPIPS expects values in range [-1, 1]
            images_lpips = images * 2.0 - 1.0
            x_hat_lpips = out_net['x_hat'] * 2.0 - 1.0
            lpips_value = lpips_loss(images_lpips, x_hat_lpips).mean().item()
            lpips_meter.update(lpips_value, images.size(0))

    print(f"\n{model_name} Results:")
    print(f"  PSNR: {psnr_meter.avg:.3f} dB")
    print(f"  BPP:  {bpp_meter.avg:.4f}")
    print(f"  LPIPS: {lpips_meter.avg:.4f}")

    return {
        'PSNR': psnr_meter.avg,
        'BPP': bpp_meter.avg,
        'LPIPS': lpips_meter.avg
    }

def main():
    # Configuration
    checkpoint_path = '/content/checkpoint_best_loss.pth.tar'  # Path to the best checkpoint
    kodak_dataset_path = '/content/drive/My Drive/kodak'  # Path to Kodak dataset
    patch_size = (256, 256)
    batch_size = 8
    quality = 3
    N = 192  # HFLIC channel parameters
    M = 320

    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load Kodak Dataset
    test_transforms = transforms.Compose([
        transforms.Resize(patch_size),
        transforms.ToTensor()
    ])

    kodak_dataset = KodakDataset(root_dir=kodak_dataset_path, transform=test_transforms)
    kodak_dataloader = DataLoader(
        kodak_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    print("=" * 60)
    print("COMPREHENSIVE EVALUATION ON KODAK DATASET")
    print("=" * 60)

    # ============================================
    # 1. Evaluate Pretrained Baseline Model
    # ============================================
    print("\n" + "=" * 60)
    print("1. Evaluating PRETRAINED Baseline Model")
    print("=" * 60)

    from compressai.zoo import image_models
    pretrained_model = image_models["mbt2018-mean"](quality=quality, pretrained=True).to(device)
    pretrained_results = evaluate_model(
        pretrained_model,
        kodak_dataloader,
        device,
        model_name="Pretrained Baseline"
    )

    # ============================================
    # 2. Evaluate Fine-tuned Model
    # ============================================
    print("\n" + "=" * 60)
    print("2. Evaluating FINE-TUNED Model (LPIPS)")
    print("=" * 60)

    try:
        # Initialize HFLIC model exactly as in training
        from models.hflic import HFLIC
        fine_tuned_model = HFLIC(N=N, M=M).to(device)

        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Handle DataParallel if needed
        state_dict = checkpoint['state_dict']
        if any(k.startswith('module.') for k in state_dict.keys()):
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            state_dict = new_state_dict

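        # Keep only checkpoint tensors whose name and shape match the model.
        # Every skipped tensor is reported: the corresponding model parameter
        # keeps its initial value, which can silently ruin the evaluation.
        filtered_state_dict = {}
        model_state_dict = fine_tuned_model.state_dict()
        for k, v in state_dict.items():
            if k not in model_state_dict:
                print(f"  Warning: checkpoint key not found in model, skipping: {k}")
            elif v.shape == model_state_dict[k].shape:
                filtered_state_dict[k] = v
            elif v.numel() == 0:
                # Lazy-initialized parameter saved with size 0 (initialized on first forward)
                print(f"  Skipping lazy-initialized parameter: {k}")
            else:
                # Shape mismatch: the tensor is skipped, not loaded
                print(f"  Warning: Size mismatch for {k}: checkpoint {v.shape} vs model {model_state_dict[k].shape} (skipped)")

        load_result = fine_tuned_model.load_state_dict(filtered_state_dict, strict=False)
        if load_result.missing_keys:
            # Model parameters that received no value from the checkpoint
            print(f"  Warning: {len(load_result.missing_keys)} model keys were not loaded: {load_result.missing_keys}")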
        fine_tuned_model.eval()

        fine_tuned_results = evaluate_model(
            fine_tuned_model,
            kodak_dataloader,
            device,
            model_name="Fine-tuned (LPIPS)"
        )
    except Exception as e:
        print(f"\nError loading fine-tuned model: {e}")
        import traceback
        traceback.print_exc()
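        # NaN marks the metrics as unavailable; zeros would read as real measurements in the table
        fine_tuned_results = {'PSNR': float('nan'), 'BPP': float('nan'), 'LPIPS': float('nan')}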

    # ============================================
    # 3. Comparison Table
    # ============================================
    print("\n" + "=" * 60)
    print("3. COMPARISON TABLE")
    print("=" * 60)

    comparison_data = {
        'Model': ['Pretrained Baseline', 'Fine-tuned (LPIPS)'],
        'PSNR (dB)': [f"{pretrained_results['PSNR']:.3f}", f"{fine_tuned_results['PSNR']:.3f}"],
        'BPP': [f"{pretrained_results['BPP']:.4f}", f"{fine_tuned_results['BPP']:.4f}"],
        'LPIPS': [f"{pretrained_results['LPIPS']:.4f}", f"{fine_tuned_results['LPIPS']:.4f}"]
    }

    df = pd.DataFrame(comparison_data)
    print("\n" + df.to_string(index=False))

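    # Per-metric differences (fine-tuned minus pretrained).
    # PSNR and LPIPS are only directly comparable when both models run at a similar BPP.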
    if fine_tuned_results['PSNR'] > 0:
        print("\n" + "-" * 60)
        print("IMPROVEMENTS (Fine-tuned vs Pretrained):")
        print("-" * 60)
        psnr_diff = fine_tuned_results['PSNR'] - pretrained_results['PSNR']
        bpp_diff = fine_tuned_results['BPP'] - pretrained_results['BPP']
        lpips_diff = fine_tuned_results['LPIPS'] - pretrained_results['LPIPS']

        print(f"  PSNR:  {psnr_diff:+.3f} dB ({'↑ Improved' if psnr_diff > 0 else '↓ Worse'})")
        print(f"  BPP:   {bpp_diff:+.4f} ({'↓ Improved' if bpp_diff < 0 else '↑ Worse'})")
        print(f"  LPIPS: {lpips_diff:+.4f} ({'↓ Improved' if lpips_diff < 0 else '↑ Worse'})")

    print("\n" + "=" * 60)
    print("Evaluation Complete!")
    print("=" * 60)

if __name__ == "__main__":
    main()

Using device: cuda
COMPREHENSIVE EVALUATION ON KODAK DATASET

1. Evaluating PRETRAINED Baseline Model
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-3-723404a8.pth.tar" to /root/.cache/torch/hub/checkpoints/mbt2018-mean-3-723404a8.pth.tar


100%|██████████| 27.6M/27.6M [00:02<00:00, 12.4MB/s]


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


100%|██████████| 233M/233M [00:01<00:00, 244MB/s]


Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth

Pretrained Baseline Results:
  PSNR: 30.708 dB
  BPP:  0.3036
  LPIPS: 0.2062

2. Evaluating FINE-TUNED Model (LPIPS)
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth

Fine-tuned (LPIPS) Results:
  PSNR: 6.554 dB
  BPP:  102.0340
  LPIPS: 1.0055

3. COMPARISON TABLE

              Model PSNR (dB)      BPP  LPIPS
Pretrained Baseline    30.708   0.3036 0.2062
 Fine-tuned (LPIPS)     6.554 102.0340 1.0055

------------------------------------------------------------
IMPROVEMENTS (Fine-tuned vs Pretrained):
------------------------------------------------------------
  PSNR:  -24.154 dB (↓ Worse)
  BPP:   +101.7305 (↑ Worse)
  LPIPS: +0.7993 (↑ Worse)

Evaluation Complete!
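
The fine-tuned numbers in the table above look like a loading or evaluation problem rather than an effect of LPIPS fine-tuning. Kodak images are 8-bit RGB, so storing them uncompressed costs 24 bpp, yet the fine-tuned model reports about 102 bpp. The sketch below turns that observation into a small check that can be run on any results dictionary with the `PSNR`, `BPP`, and `LPIPS` keys used in the comparison table. The helper names `normalized_rmse` and `sanity_check` and the 15 dB threshold are illustrative assumptions, not part of the notebook or of CompressAI.

```python
import math

# Cost of storing 8-bit RGB uncompressed: 3 channels x 8 bits per pixel.
RAW_RGB_BPP = 24.0


def normalized_rmse(psnr_db: float) -> float:
    """RMS error as a fraction of the peak pixel value.

    Follows from PSNR = 10 * log10(MAX^2 / MSE), so RMSE / MAX = 10^(-PSNR / 20).
    """
    return 10.0 ** (-psnr_db / 20.0)


def sanity_check(results: dict, min_psnr_db: float = 15.0) -> list:
    """Return warnings for metric values that look implausible.

    min_psnr_db is an illustrative threshold (an assumption), not a standard.
    """
    problems = []
    for name in ("PSNR", "BPP", "LPIPS"):
        if not math.isfinite(results[name]):
            problems.append(f"{name} is not finite")
    if results["BPP"] > RAW_RGB_BPP:
        problems.append(
            f"BPP {results['BPP']:.2f} exceeds uncompressed 8-bit RGB "
            f"({RAW_RGB_BPP:.0f} bpp)"
        )
    if results["PSNR"] < min_psnr_db:
        problems.append(
            f"PSNR {results['PSNR']:.2f} dB implies an RMS error of "
            f"{normalized_rmse(results['PSNR']):.0%} of the pixel range"
        )
    return problems


# Values copied from the comparison table above.
baseline = {"PSNR": 30.708, "BPP": 0.3036, "LPIPS": 0.2062}
fine_tuned = {"PSNR": 6.554, "BPP": 102.0340, "LPIPS": 1.0055}

for model_name, results in [("Pretrained Baseline", baseline),
                            ("Fine-tuned (LPIPS)", fine_tuned)]:
    problems = sanity_check(results)
    status = "OK" if not problems else "; ".join(problems)
    print(f"{model_name}: {status}")
```

When the check flags a model, the messages printed while loading the checkpoint (skipped, mismatched, or missing keys) are a reasonable first place to look, since any parameter that is not loaded keeps its initial value.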
