## SR

Relevant chats:
* https://chatgpt.com/share/68c86cf9-29f4-800b-a073-43a69e79b17f
* https://chatgpt.com/share/68c7061b-5bdc-800b-9821-48fd1625ed65
* https://chatgpt.com/share/68c6111a-55d0-800b-b1b3-14e3cf3ff731

In [1]:
import IPython.display as ipd
import torch
import torch.nn as nn
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from PIL import Image

from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device='cuda'

In [2]:
from elucidated_diffusion.image_helpers import sr_to_pil
from elucidated_diffusion.image_helpers import pil_to_data_url
from elucidated_diffusion.image_helpers import html_for_images
from elucidated_diffusion.checkpoint_helper import load_checkpoint, save_checkpoint
from elucidated_diffusion.models.chatgpt_sr_unet import UNetSR3

from elucidated_diffusion.elucidated_diffusion import edm_ancestral_sampling_for_diffusion
from elucidated_diffusion.elucidated_diffusion import edm_ancestral_sampling_for_sr
from elucidated_diffusion.elucidated_diffusion import P_mean, P_std, sigma_data, edm_loss_weight
from elucidated_diffusion.dataset_helpers import get_datasets



In [5]:
# Sanity check the model's inpus and outputs are as expected
if sanity_check_model := True:
    B = 4
    hrimg = torch.randn(B, 3, 256, 256).cuda()
    lrimg = torch.randn(B, 3, 64, 64).cuda()
    t = torch.rand(B).cuda()

    #lr_up = F.interpolate(lrimg, size=(256,256), mode='bilinear', align_corners=False)
    #x_noisy = torch.randn_like(hrimg)
    #x_in = torch.cat([x_noisy, lr_up], dim=1).cuda()

    model = UNetSR3().cuda()
    with torch.no_grad():
        out = model(x_noisy, lrimg, t)
    print("Output:", out.shape)  # should be [4, 3, 256, 256]
    num_params = sum(p.numel() for p in model.parameters())
    print(f"SRUNet number of parameters: {num_params:,}")

Output: torch.Size([4, 3, 256, 256])
SRUNet number of parameters: 3,747,907


## Ancestral Sampling.

In [6]:
if sanity_check_model := True:
    with torch.no_grad():
        edm_ancestral_sampling_for_sr(model,lrimg,batch_size=4,img_shape=(3,256,256))

In [11]:
# Modified from my diffusion model in a cell above.

from torch.amp.grad_scaler import GradScaler
scaler = GradScaler(device="cuda")
#from torch.cuda.amp import autocast
from torch.amp.autocast_mode import autocast


def train_a_batch(model_edm, optimizer_edm, batch):
 
        x = batch
        hr_256, lr_64 = x
        # Move to device
        hr_256 = hr_256.to(device)          # [B, 3, 256, 256]
        lr_64 = lr_64.to(device)            # [B, 3, 64, 64]

        x = hr_256
        # Log-normal sigma sampling (as in NVIDIA's implementation)
        rnd_normal = torch.randn([x.shape[0], 1, 1, 1], device=device)
        sigma = (rnd_normal * P_std + P_mean).exp()
        
        # Preconditioning coefficients (from NVIDIA's EDMPrecond)
        c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
        c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2).sqrt()
        c_in = 1 / (sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4
        
        # Add noise: y_noisy = x + σ·n where n ~ N(0,I)
        noise = torch.randn_like(x) * sigma
        y_noisy = x + noise
        
        # Model prediction: F_x = net(c_in * y_noisy, c_noise)
        c_noise_input = c_noise.view(x.shape[0])


        # Upsample LR to HR resolution
        lr_up = F.interpolate(lr_64, size=hr_256.shape[-2:], mode='bilinear', align_corners=False)
        # Expected shape: [B, 3, 256, 256]

        ## Sample timestep t for each item in batch
        #b = hr_256.shape[0]
        
        # Concatenate noisy HR + upsampled LR for conditioning
        #x_input = torch.cat([y_noisy, lr_up], dim=1)  # [B, 6, 256, 256]

        # --- Optional sanity assertions ---
        #assert x_input.shape[1] == 6, f"Expected 6 channels, got {x_input.shape[1]}"
        #assert hr_256.shape == noise.shape, "HR and noise shapes must match"


        # this_models_input = torch.cat([c_in * y_noisy, lr_up], dim=1).cuda()
        # From ChatGPT code review
        # this_models_input = torch.cat([c_in * y_noisy, c_in * lr_up], dim=1).cuda()
        # TODO: I'm not sure I trust that -- it looked pretty good before this change.
        # but then in a second chatgpt code review it asked me to reverse it again
        # https://chatgpt.com/share/68c7061b-5bdc-800b-9821-48fd1625ed65
        # Double-check this after it runs a long time
        #this_models_input = torch.cat([c_in * y_noisy, lr_up], dim=1)
        
        with autocast(device_type="cuda"):
            #F_x = model_edm(this_models_input, c_noise_input)
            F_x = model_edm(c_in * y_noisy, lr_up, c_noise_input)
            # Preconditioning: D_x = c_skip * y_noisy + c_out * F_x
            D_x = c_skip * y_noisy + c_out * F_x
            # Loss weight: λ(σ) = (σ² + σ_data²) / (σ·σ_data)²
            #weight = edm_loss_weight(sigma.squeeze(), sigma_data).view(-1, 1, 1, 1)
            # 
            weight = edm_loss_weight(sigma.flatten(), sigma_data)[:, None, None, None]

            # Loss: weighted MSE between preconditioning output and clean image
            loss = (weight * (D_x - x) ** 2).mean()
        
        optimizer_edm.zero_grad()
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model_edm.parameters(), max_norm=1.0)
        scaler.step(optimizer_edm)
        scaler.update()
    
        return loss.item()


In [12]:
if sanity_check_model := True:

    model = UNetSR3().to(device)
    lr = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    experiment_name = 'fantasy'
    model = UNetSR3().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    lrds,paired_dataset = get_datasets('fantasy')
    train_loader = DataLoader(paired_dataset, batch_size=4, shuffle=False)
    batch = next(iter(train_loader))
    train_a_batch(model, optimizer, batch)
    del(model)
    del(optimizer)
    import gc
    gc.collect()


In [13]:
import time
from datetime import datetime
from tqdm import tqdm
import os

from elucidated_diffusion.ema_helper import EMAHelper

def train_some_sr(model_edm, optimizer_edm, paired_dataset, resume_from=None, experiment_name=None):
    train_loader = DataLoader(paired_dataset, batch_size=16, shuffle=True)
    #train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    if resume_from:
        load_checkpoint(model=model_edm, optimizer=optimizer_edm, path=resume_from)
    ema = EMAHelper(model_edm)    
    t0 = time.time()
    display_interval = 60
    next_display_time = time.time()
    checkpoint_interval = 60 * 60
    next_checkpoint_time = time.time() + checkpoint_interval
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    with open("/tmp/training_running.txt","w"):
        pass
    for epoch in range(9999999):
        if not os.path.exists("/tmp/training_running.txt"):
            break
        pbar = tqdm(train_loader)
        for x in pbar:
            if not os.path.exists("/tmp/training_running.txt"):
                break
            hr,lr = x
            #if not os.path.exists("/tmp/training_running.txt"):
            #    break
            loss = train_a_batch(model_edm,optimizer_edm,x)
            ema.update(model_edm)
            if experiment_name and (time.time() > next_checkpoint_time):
                path = save_checkpoint(model_edm, optimizer_edm, 85, 0.11, tag="_SR_"+experiment_name + "_raw_sr_tmp")
                path = save_checkpoint(ema.get_model(), optimizer_edm, 85, 0.11, tag="_SR_"+experiment_name + "_ema_sr_tmp")
                next_checkpoint_time = time.time() + checkpoint_interval
            if time.time() > next_display_time:
                with torch.no_grad():
                    next_display_time = next_display_time + 30
                    emamodel = ema.get_model()
                    out = edm_ancestral_sampling_for_sr(emamodel, lr[0:2].cuda(), num_steps=18,batch_size=2,img_shape=(3,256,256))
                    out2 = edm_ancestral_sampling_for_sr(emamodel, lr[0:2].cuda(), num_steps=18,batch_size=2,img_shape=(3,256,256), headstart_sigma=3)
                    out3 = edm_ancestral_sampling_for_sr(emamodel, lr[0:2].cuda(), num_steps=18,batch_size=2,img_shape=(3,256,256), headstart_sigma=1)
                    hri,lri,sri,sri2,sri3 = [sr_to_pil(i) for i in [hr[0],lr[0],out[0],out2[0],out3[0]]]
                    dt = time.strftime('%H:%M:%S', time.gmtime(time.time() - t0))
                    title = f"SR training at {dt} loss {loss}"
                    h = html_for_images([hri,lri,lri.resize((256,256)),sri,sri2,sri3], title=title)
                    ipd.clear_output(wait=True)
                    ipd.display(ipd.HTML(h))

            pbar.set_description(f"EDM Epoch {epoch+1} Loss: {loss:.4f}")

    if experiment_name:
        path = save_checkpoint(model_edm, optimizer_edm, 85, 0.11, tag="_SR_"+experiment_name + "_raw_sr_tmp")
        path = save_checkpoint(ema.get_model(), optimizer_edm, 85, 0.11, tag="_SR_"+experiment_name + "_ema_sr_tmp")
        
    del model_edm,optimizer_edm
    return path if experiment_name else ema

In [14]:
experiment_name = 'fantasy'
model = UNetSR3().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
batch_size = 8
train_dataset,paired_dataset = get_datasets(experiment_name)
#good_path = 'checkpoints/UNetSR3_2025-09-15_08-14-19__SR_fantasy_really_good.pth'
#good_path = 'checkpoints/UNetSR3_2025-09-15_19-38-09__SR_dragon_really_good.pth'
#good_path = 'checkpoints/UNetSR3_2025-09-15_08-14-19__SR_fantasy_really_good.pth'
#load_checkpoint(model, optimizer, good_path)
#ipd.display(sr_to_pil(next(iter(paired_dataset))[0]))

In [15]:

train_some_sr(model, optimizer,paired_dataset)

EDM Epoch 1 Loss: 0.4465:   8%|▊         | 67/801 [00:26<04:51,  2.52it/s]


KeyboardInterrupt: 

In [17]:
good_path = 'checkpoints/UNetSR3_2025-09-15_08-14-19__SR_fantasy_really_good.pth'
load_checkpoint(model, optimizer, good_path)
pass

🔄 Loaded checkpoint from checkpoints/UNetSR3_2025-09-15_08-14-19__SR_fantasy_really_good.pth
    Epoch: 85, Loss: 0.11
    Model class: UNetSR3


In [18]:
train_some_sr(model, optimizer,paired_dataset)

EDM Epoch 1 Loss: 12.2907:   3%|▎         | 21/801 [00:10<06:26,  2.02it/s]


KeyboardInterrupt: 