# HiFiC (Low-memory) — Jupyter Notebook

This notebook is a low-memory, GPU-friendly reproduction skeleton for **HiFiC (NeurIPS 2020)**.
- Designed to run on a **4 GB GPU** (or CPU) with conservative defaults.
- Includes robust LPIPS initialization (falls back if offline), AMP, gradient accumulation, and a flat-image dataloader.

Use this notebook to debug and run small experiments. For paper-scale training you will need larger GPUs and longer runs.


In [1]:
# Cell 1 — Install required packages (run once)
import sys
print('Python executable:', sys.executable)
# try:
#     # Install only if missing — noisy but useful in fresh envs
#     get_ipython().system('pip install --quiet torch torchvision lpips pytorch-fid pytorch-msssim tqdm Pillow scikit-image')
# except Exception as e:
#     print('Install step skipped or failed:', e)


Python executable: E:\Software_installation\Anaconda\anaconda3\envs\ipc\python.exe


In [2]:
# Cell 2 — Imports & device
import os, math, time, random
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from PIL import Image

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', DEVICE)


Device: cuda


In [3]:
# Cell 3 — Low-memory configuration
TRAIN_DIR = 'data/train'   # point to your images (flat folder or ImageFolder layout)
TEST_DIR = 'data/test'
OUT_DIR = 'outputs'
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

BATCH_SIZE = 1
CROP = 128
LEARNING_RATE = 1e-4
MAX_STEPS = 2000
SAVE_EVERY = 1000
LAMBDA_RATE = 1e-2
BETA_GAN = 0.1   # start with 0 (no GAN) to save memory
KM = 1.0
KP = 0.0         # disable LPIPS during training to save memory (can compute at eval)
N_RES = 3
BASE_CHANNELS = 24
Z_CHANNELS = 64
ACCUM_STEPS = 1  # gradient accumulation steps
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [4]:
# Cell 4 — Utilities: FlatImageFolder and demo image creator
import glob
IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

class FlatImageFolder(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.files = []
        for ext in IMG_EXTS:
            self.files += glob.glob(os.path.join(root, '**', f'*{ext}'), recursive=True)
        self.files = sorted(self.files)
        if len(self.files) == 0:
            raise FileNotFoundError(f'No images found in {root}')
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        path = self.files[idx]
        img = Image.open(path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, 0

def make_dataloader(path, batch_size=BATCH_SIZE, crop=CROP, shuffle=True):
    tf = transforms.Compose([
        transforms.RandomResizedCrop(crop, scale=(0.5,1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    if os.path.exists(path):
        # prefer ImageFolder if it has class subfolders
        try:
            from torchvision import datasets
            has_subdirs = any(os.path.isdir(os.path.join(path, d)) for d in os.listdir(path))
            if has_subdirs:
                ds = datasets.ImageFolder(path, transform=tf)
            else:
                ds = FlatImageFolder(path, transform=tf)
        except Exception:
            ds = FlatImageFolder(path, transform=tf)
    else:
        raise FileNotFoundError(f'Path {path} does not exist')
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=True)


def create_demo_images(folder='demo_images', n=8, size=256):
    os.makedirs(folder, exist_ok=True)
    for i in range(n):
        arr = (np.random.rand(size, size, 3) * 255).astype(np.uint8)
        Image.fromarray(arr).save(f"{folder}/img_{i:03d}.png")
    return folder


In [5]:
# Cell 5 — Models (low-memory sizes)
class ChannelNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))
        self.eps = eps
    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        var = ((x - mu) ** 2).mean(dim=1, keepdim=True)
        sigma = torch.sqrt(var + self.eps)
        x_norm = (x - mu) / sigma
        a = self.alpha.view(1, -1, 1, 1)
        b = self.beta.view(1, -1, 1, 1)
        return x_norm * a + b

def conv_block(in_ch, out_ch, stride=1, kernel=3, norm=True):
    pad = kernel // 2
    layers = [nn.Conv2d(in_ch, out_ch, kernel, stride=stride, padding=pad)]
    if norm:
        layers.append(ChannelNorm(out_ch))
    layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    def __init__(self, in_ch=3, base=BASE_CHANNELS):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, base, 7, padding=3), ChannelNorm(base), nn.ReLU(inplace=True),
            nn.Conv2d(base, base*2, 3, stride=2, padding=1), ChannelNorm(base*2), nn.ReLU(inplace=True),
            nn.Conv2d(base*2, base*4, 3, stride=2, padding=1), ChannelNorm(base*4), nn.ReLU(inplace=True),
            nn.Conv2d(base*4, base*8, 3, stride=2, padding=1), ChannelNorm(base*8), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class ResidualBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.norm1 = ChannelNorm(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.norm2 = ChannelNorm(ch)
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.norm1(out))
        out = self.conv2(out)
        out = self.norm2(out)
        return x + out

class Generator(nn.Module):
    def __init__(self, out_ch=3, base=BASE_CHANNELS, n_res=N_RES):
        super().__init__()
        self.initial = nn.Conv2d(base*8, base*8, 3, padding=1)
        self.res_blocks = nn.Sequential(*[ResidualBlock(base*8) for _ in range(n_res)])
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base*8, base*4, 3, padding=1), ChannelNorm(base*4), nn.ReLU())
        self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base*4, base*2, 3, padding=1), ChannelNorm(base*2), nn.ReLU())
        self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base*2, base, 3, padding=1), ChannelNorm(base), nn.ReLU())
        self.final = nn.Conv2d(base, out_ch, 3, padding=1)
    def forward(self, y_quant):
        x = self.initial(y_quant)
        x = self.res_blocks(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        return torch.clamp(torch.sigmoid(self.final(x)), 0.0, 1.0)

class HyperEncoder(nn.Module):
    def __init__(self, in_ch=BASE_CHANNELS*8, z_channels=Z_CHANNELS):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, in_ch//2, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_ch//2, z_channels, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(z_channels, z_channels, 3, stride=2, padding=1),
        )
    def forward(self, y):
        return self.net(y)

class HyperDecoder(nn.Module):
    def __init__(self, z_channels=Z_CHANNELS, out_ch=BASE_CHANNELS*8*2):
        super().__init__()
        mid = out_ch//2
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_channels, z_channels, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(z_channels, mid, 4, stride=2, padding=1),
        )
    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, in_ch=3, cond_ch=3):
        super().__init__()
        def conv(cin, cout, k=4, s=2, p=1):
            return nn.Sequential(nn.Conv2d(cin, cout, k, stride=s, padding=p), nn.LeakyReLU(0.2, inplace=True))
        self.net = nn.Sequential(
            conv(in_ch + cond_ch, 64),
            conv(64, 128),
            conv(128, 256),
            conv(256, 512),
            nn.Conv2d(512, 1, 1)
        )
    def forward(self, x, y_cond):
        inp = torch.cat([x, y_cond], dim=1)
        return self.net(inp)   # RETURN LOGITS (no sigmoid)



In [6]:
# Cell 6 — Quantization, entropy helpers, and safe ops
def add_uniform_noise(x):
    return x + (torch.rand_like(x) - 0.5)

def round_st(x):
    return (x.round() - x).detach() + x

def gaussian_log_prob(x, mu, sigma):
    eps = 1e-9
    var = (sigma + eps) ** 2
    log_prob = -0.5 * ( (x - mu)**2 / var + torch.log(2 * math.pi * var) )
    return log_prob

def estimate_bits(y_quant, mu, sigma):
    logp = gaussian_log_prob(y_quant, mu, sigma)
    bits = -logp.sum() / math.log(2.0)
    return bits


In [7]:
# Cell 7 — Losses with robust LPIPS init
import warnings
try:
    import lpips
except Exception:
    lpips = None

def make_lpips(device):
    if lpips is None:
        warnings.warn('lpips package not available; perceptual loss disabled.')
        return None
    try:
        p = lpips.LPIPS(net='alex')
        p = p.to(device)
        return p
    except Exception as e:
        warnings.warn('Could not load pretrained LPIPS network (offline?). Falling back to random init. ' + str(e))
        return lpips.LPIPS(net='alex', pnet_rand=True).to(device)

LPIPS = make_lpips(DEVICE)

bce_logits = nn.BCEWithLogitsLoss()

def distortion_loss(x, x_rec, kM=KM, kP=KP):
    mse = F.mse_loss(x_rec, x)
    lp = torch.tensor(0.0, device=x.device)
    if LPIPS is not None and kP > 0:
        try:
            lp = LPIPS((x*2-1), (x_rec*2-1)).mean()
        except Exception as e:
            warnings.warn('LPIPS computation failed: ' + str(e))
            lp = torch.tensor(0.0, device=x.device)
    return kM * mse + kP * lp, float(mse.item()), float(lp.item() if isinstance(lp, torch.Tensor) else lp)

def gan_generator_loss(d_fake_logits):
    # wants discriminator logits for fake samples
    return bce_logits(d_fake_logits, torch.ones_like(d_fake_logits))

def gan_discriminator_loss(d_real_logits, d_fake_logits):
    # d_real_logits: logits for real images, d_fake_logits: logits for fake images (detached)
    loss_real = bce_logits(d_real_logits, torch.ones_like(d_real_logits))
    loss_fake = bce_logits(d_fake_logits, torch.zeros_like(d_fake_logits))
    return loss_real + loss_fake


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: E:\Software_installation\Anaconda\anaconda3\envs\ipc\Lib\site-packages\lpips\weights\v0.1\alex.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [8]:
# Cell 8 — Training loop (AMP + accumulation, low-memory) — corrected AMP usage
from torch.amp import autocast, GradScaler

def train_loop(data_dir=TRAIN_DIR, max_steps=MAX_STEPS):
    # prepare dataloader
    if not os.path.isdir(data_dir) or len(os.listdir(data_dir))==0:
        print('No dataset found at', data_dir, "— creating demo images in 'demo_images' and using them.")
        demo = create_demo_images('demo_images', n=16, size=256)
        data_dir = demo
    dl = make_dataloader(data_dir, batch_size=BATCH_SIZE)

    # instantiate models
    E = Encoder().to(DEVICE)
    # generator expects channels BASE_CHANNELS*8
    G = Generator(n_res=N_RES).to(DEVICE)
    H_enc = HyperEncoder().to(DEVICE)
    H_dec = HyperDecoder().to(DEVICE)
    D = None
    if BETA_GAN > 0:
        D = Discriminator(cond_ch=3).to(DEVICE)

    optEG = torch.optim.Adam(list(E.parameters()) + list(G.parameters()) + list(H_enc.parameters()) + list(H_dec.parameters()), lr=LEARNING_RATE, betas=(0.9,0.999))
    optD = torch.optim.Adam(D.parameters(), lr=LEARNING_RATE, betas=(0.9,0.999)) if D is not None else None

    # Safe GradScaler init for different torch versions
    try:
        scaler = GradScaler(device=torch.device('cuda')) if DEVICE.startswith('cuda') else None
    except TypeError:
        # older torch versions may not accept device keyword
        scaler = GradScaler() if DEVICE.startswith('cuda') else None

    # Prepare autocast kwargs
    autocast_kwargs = {"device_type": "cuda"} if DEVICE.startswith('cuda') else {"device_type": "cpu"}

    step = 0
    while step < max_steps:
        for batch in dl:
            imgs, _ = batch
            imgs = imgs.to(DEVICE)

            # Use modern autocast
            with autocast(**autocast_kwargs):
                y = E(imgs)
                y_tilde = add_uniform_noise(y)
                z = H_enc(y_tilde)
                z_q = round_st(z)
                hyper = H_dec(z_q)
                C_y = y_tilde.shape[1]
                if hyper.shape[1] >= C_y*2:
                    mu = hyper[:, :C_y, :, :]
                    sigma = F.softplus(hyper[:, C_y:C_y*2, :, :]) + 1e-6
                else:
                    mu = torch.zeros_like(y_tilde)
                    sigma = torch.ones_like(y_tilde)
                y_q = round_st(y)
                x_rec = G(y_q)
                # discriminator conditioning (simple)
                y_up = F.interpolate(y_q, size=imgs.shape[2:], mode='nearest')
                if y_up.shape[1] >= 3:
                    y_cond = y_up[:, :3, :, :].detach()
                else:
                    y_cond = y_up.repeat(1, 3//y_up.shape[1] + 1, 1, 1)[:, :3, :, :].detach()

                # discriminator step (optional)
                if D is not None:
                    d_real = D(imgs, y_cond)
                    d_fake = D(x_rec.detach(), y_cond)
                    lossD = gan_discriminator_loss(d_real, d_fake)
                else:
                    lossD = None

                # generator & encoder loss
                if D is not None:
                    d_fake_forG = D(x_rec, y_cond)
                    gen_gan = gan_generator_loss(d_fake_forG)
                else:
                    gen_gan = 0.0

                d_loss, mse_val, lpips_val = distortion_loss(imgs, x_rec, kM=KM, kP=KP)
                bits = estimate_bits(y_q, mu, sigma)
                rate = bits / (imgs.shape[0] * imgs.shape[2] * imgs.shape[3])
                lossG = LAMBDA_RATE * rate + d_loss + BETA_GAN * gen_gan

            # backprop (with scaler if using AMP)
            if scaler is not None:
                scaler.scale(lossG).backward()
            else:
                lossG.backward()

            # step optimizers with accumulation
            if (step + 1) % ACCUM_STEPS == 0:
                if scaler is not None:
                    # step generator/encoder optimizer
                    scaler.step(optEG)
                    scaler.update()
                else:
                    optEG.step()
                optEG.zero_grad()

                # Note: discriminator step was computed earlier; if you enable GAN training (BETA_GAN>0)
                # you should add explicit backward/step for optD here (similar to optEG).
                if D is not None and lossD is not None:
                    # Basic D update (unscaled if no scaler, or scaled if using scaler)
                    if scaler is not None:
                        scaler.scale(lossD).backward()
                        scaler.step(optD)
                        scaler.update()
                    else:
                        lossD.backward()
                        optD.step()
                    optD.zero_grad()

            if step % 50 == 0:
                # guard float conversion in case lossG is a tensor on CUDA
                print(f'Step {step} | lossG {float(lossG):.6f} | rate bpp {float(rate):.6f} | mse {mse_val:.6f} | lpips {lpips_val:.6f}')
            if step % SAVE_EVERY == 0 and step>0:
                ckpt = {'E':E.state_dict(), 'G':G.state_dict(), 'H_enc':H_enc.state_dict(), 'H_dec':H_dec.state_dict(), 'step': step}
                torch.save(ckpt, os.path.join(CHECKPOINT_DIR, f'hific_lowmem_ckpt_{step}.pt'))
                print('Saved checkpoint', step)
            step += 1
            if step >= max_steps:
                break
    return {'E':E, 'G':G, 'H_enc':H_enc, 'H_dec':H_dec}


In [9]:
# Cell 9 — Evaluation helper (PSNR and optional LPIPS)
from skimage.metrics import peak_signal_noise_ratio as sk_psnr
def save_image_tensor(tensor, filename):
    utils.save_image(tensor.clamp(0,1), filename)

def psnr_np(a, b, data_range=1.0):
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return 100.0
    return 10 * math.log10((data_range ** 2) / mse)

def evaluate_and_save(models, data_dir=TEST_DIR, out_dir=OUT_DIR, n_images=8):
    E, G = models['E'], models['G']
    E.eval(); G.eval()
    if not os.path.isdir(data_dir) or len(os.listdir(data_dir))==0:
        data_dir = create_demo_images('demo_images_eval', n=8, size=256)
    dl = make_dataloader(data_dir, batch_size=1, crop=CROP, shuffle=False)
    os.makedirs(out_dir, exist_ok=True)
    psnrs = []
    lpips_vals = []
    cnt = 0
    with torch.no_grad():
        for (img,_) in dl:
            img = img.to(DEVICE)
            y = E(img)
            y_q = round_st(y)
            rec = G(y_q)
            a = img[0].cpu().numpy().transpose(1,2,0)
            b = rec[0].cpu().numpy().transpose(1,2,0)
            ps = psnr_np(a, b)
            psnrs.append(ps)
            if LPIPS is not None:
                try:
                    lp = LPIPS((img*2-1), (rec*2-1)).item()
                except Exception:
                    lp = 0.0
            else:
                lp = 0.0
            lpips_vals.append(lp)
            save_image_tensor(torch.cat([img, rec], dim=0), f"{out_dir}/recon_{cnt:03d}.png")
            cnt += 1
            if cnt >= n_images:
                break
    print(f"Avg PSNR: {np.mean(psnrs):.3f} dB | Avg LPIPS: {np.mean(lpips_vals):.6f}")
    return psnrs, lpips_vals


In [10]:
# # Cell 10 — Quick demo run (small). Run this cell to test the pipeline.
# if __name__ == '__main__':
#     print('Starting small demo run...')
#     models = train_loop(max_steps=40050)
    


Starting small demo run...
Step 0 | lossG 0.172354 | rate bpp 5.137017 | mse 0.051800 | lpips 0.000000
Step 50 | lossG 0.156578 | rate bpp 4.287329 | mse 0.033103 | lpips 0.000000
Step 100 | lossG 0.142381 | rate bpp 4.126629 | mse 0.040217 | lpips 0.000000
Step 150 | lossG 0.150415 | rate bpp 4.225074 | mse 0.038832 | lpips 0.000000
Step 200 | lossG 0.168539 | rate bpp 4.354912 | mse 0.051490 | lpips 0.000000
Step 250 | lossG 0.152330 | rate bpp 4.298512 | mse 0.045510 | lpips 0.000000
Step 300 | lossG 0.142002 | rate bpp 4.297060 | mse 0.037156 | lpips 0.000000
Step 350 | lossG 0.161624 | rate bpp 4.280373 | mse 0.035398 | lpips 0.000000
Step 400 | lossG 0.190848 | rate bpp 4.624404 | mse 0.078942 | lpips 0.000000
Step 450 | lossG 0.155369 | rate bpp 4.324577 | mse 0.050611 | lpips 0.000000
Step 500 | lossG 0.155060 | rate bpp 4.174443 | mse 0.046105 | lpips 0.000000
Step 550 | lossG 0.151983 | rate bpp 4.351433 | mse 0.035353 | lpips 0.000000
Step 600 | lossG 0.137539 | rate bpp 4.2

In [20]:
evaluate_and_save(models, n_images=10)

Avg PSNR: 21.409 dB | Avg LPIPS: 0.271036


([22.16738169910996,
  21.67873405061355,
  20.10246828911869,
  22.32617131945135,
  19.234802182771325,
  25.121291969351915,
  24.402964387472856,
  19.32816729103061,
  21.018302995173737,
  18.706387094401478],
 [0.32690465450286865,
  0.24369613826274872,
  0.21585050225257874,
  0.2266412228345871,
  0.40402546525001526,
  0.22896550595760345,
  0.281113862991333,
  0.23242506384849548,
  0.33271151781082153,
  0.21802504360675812])