# Implementation of NCSN (Noise Conditional Score Networks)

In [1]:
# Turn on IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Environment check: list the visible GPUs and show which `python` is on the shell PATH.
!nvidia-smi
!which python

Thu Oct 24 16:50:34 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   43C    P0              43W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  | 00000004:05:00.0 Off |  

In [2]:
try:
    # First attempt: project code importable as the `NCSN` package.
    from NCSN.train import train, make_dataset, cal_noise_level, langevin
    from NCSN.utils import train_set, val_set
    from NCSN.ncsnv2 import NCSNv2
except ImportError:
    # Fallback: the same modules importable as top-level modules
    # (presumably when the notebook runs from inside the NCSN directory).
    # Only import failures trigger the fallback, so KeyboardInterrupt and
    # genuine errors raised while importing the package are no longer masked.
    from train import train, make_dataset, cal_noise_level, langevin
    from utils import train_set, val_set
    from ncsnv2 import NCSNv2

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import datetime

# Shuffled loaders over the project's train/validation sets; partial final
# batches are kept and host memory is pinned.
train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)
val_loader = DataLoader(
    val_set,
    batch_size=500,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

def timestr():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``.

    The string is used in this notebook as a run identifier (``time_str``).
    """
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# model config
import yaml
import argparse

# Read the model configuration from the YAML file into a plain (possibly nested) dict.
# NOTE(review): yaml.load with FullLoader accepts more than plain-data tags; that is
# fine for a trusted local config, but prefer yaml.safe_load if this file could ever
# come from an untrusted source.
with open('NCSN/MNIST.yml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

def dict2namespace(config):
    """Convert a (possibly nested) dict into an ``argparse.Namespace`` tree.

    Every key becomes an attribute of the returned namespace. Values that are
    themselves dicts are converted recursively; any other value (including
    lists that contain dicts) is stored unchanged.

    Args:
        config: Mapping with string keys.

    Returns:
        argparse.Namespace mirroring the structure of ``config``.
    """
    result = argparse.Namespace()
    for name in config:
        entry = config[name]
        # Only dict values recurse; everything else is attached as-is.
        setattr(result, name, dict2namespace(entry) if isinstance(entry, dict) else entry)
    return result

# Rebind `config` from the raw dict to its attribute-style Namespace form;
# this is the object later passed to NCSNv2(config).
config = dict2namespace(config)

In [4]:
# Smoke-test setup: build the network on the GPU plus a dummy batch to push through it.
model = NCSNv2(config).cuda()
# 256 standard-normal inputs of shape 1x28x28.
x = torch.randn(256,1,28,28).cuda()
# One zero-valued long index per sample (presumably the noise-level label —
# TODO confirm against NCSNv2.forward).
t = torch.zeros(256,dtype=torch.long).cuda()

In [5]:
# Forward pass with autograd switched off. The shape expression sits inside the
# `with` block, so the cell itself displays nothing; only whatever the model
# prints during the forward pass shows up.
with torch.no_grad():
    model(x, t).shape

	 conv1 Time:  0.00835418701171875
	 output square mean:  tensor(0.4143, device='cuda:0')
	 conv2 Time:  0.010793685913085938
	 output square mean:  tensor(0.3564, device='cuda:0')
	 conv1 Time:  0.0005280971527099609
	 output square mean:  tensor(0.3946, device='cuda:0')
	 conv2 Time:  0.006690263748168945
	 output square mean:  tensor(0.4475, device='cuda:0')
layer1 Time:  0.8779275417327881
layer1 square mean:  tensor(1.5299, device='cuda:0')
	 conv1 Time:  0.000522613525390625
	 output square mean:  tensor(0.3431, device='cuda:0')
	 conv2 Time:  0.008485078811645508
	 output square mean:  tensor(0.2604, device='cuda:0')
	 conv1 Time:  0.0005159378051757812
	 output square mean:  tensor(0.3325, device='cuda:0')
	 conv2 Time:  0.004118442535400391
	 output square mean:  tensor(0.3642, device='cuda:0')
	 conv1 Time:  0.0005657672882080078
	 output square mean:  tensor(0.3348, device='cuda:0')
	 conv2 Time:  0.003651142120361328
	 output square mean:  tensor(0.3076, device='cuda:0')
	 

In [6]:
# Same forward pass with autograd left at its default (enabled); the cell
# displays the shape of the network output.
model(x,t).shape

	 conv1 Time:  0.0008597373962402344
	 output square mean:  tensor(0.4143, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv2 Time:  0.003847360610961914
	 output square mean:  tensor(0.3564, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv1 Time:  0.0005965232849121094
	 output square mean:  tensor(0.3946, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv2 Time:  0.006831645965576172
	 output square mean:  tensor(0.4475, device='cuda:0', grad_fn=<MeanBackward0>)
layer1 Time:  0.015146970748901367
layer1 square mean:  tensor(1.5299, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv1 Time:  0.0035626888275146484
	 output square mean:  tensor(0.3431, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv2 Time:  0.010590314865112305
	 output square mean:  tensor(0.2604, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv1 Time:  0.0005717277526855469
	 output square mean:  tensor(0.3325, device='cuda:0', grad_fn=<MeanBackward0>)
	 conv2 Time:  0.0040760040283203125
	 output square mean:  tensor(0.36

torch.Size([256, 1, 28, 28])

In [7]:
# Training configuration and launch.

# Noise schedule: the three values handed to cal_noise_level.
init_sigma, final_sigma, n_sigma = 28, 0.01, 75

# eps / T are forwarded unchanged to train().
eps, T = 5e-5, 5

# Loop settings.
epochs, eval_freq = 500, 5
ema_decay = 0.999
# ema_decay = None

criterion = nn.MSELoss()
sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

# Fresh model and optimizer for this run.
model = NCSNv2(config).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print("model:", model)
# Total element count over all parameters of the model.
num_params = sum(weight.numel() for weight in model.parameters())
print("number of parameters:", num_params)

# Timestamp identifying this run; passed to train() as time_str.
time_str = timestr()
print("time string:", time_str)

# torch.backends.cudnn.benchmark = False

train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    sigmas=sigmas,
    eps=eps,
    T=T,
    time_str=time_str,
    eval_freq=eval_freq,
    ema_decay=ema_decay,
)

Epoch 1/500:   0%|                                                                            | 0/235 [00:00<?, ?it/s]

model: NCSNv2(
  (act): ELU(alpha=1.0)
  (begin_conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (normalizer): InstanceNorm2dPlus(
    (instance_norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (end_conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (res1): ModuleList(
    (0): ResidualBlock(
      (non_linearity): ELU(alpha=1.0)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (normalize2): InstanceNorm2dPlus(
        (instance_norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (normalize1): InstanceNorm2dPlus(
        (instance_norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
    )
    (1): ResidualBlock(
      (non_linearity): ELU(alpha=1.0)
      (conv1): Conv2d(64,

Epoch 1/500:   0%|▏                                                  | 1/235 [00:10<39:31, 10.13s/it, Train Loss=1.14]

	 conv2 Time:  1.9441194534301758
	 output square mean:  tensor(0.2730, device='cuda:0')
layer4 Time:  9.932719945907593
layer4 square mean:  tensor(1.2165, device='cuda:0')
ref2 Time:  9.936400890350342
ref2 square mean:  tensor(2.9399, device='cuda:0')
ref4 Time:  9.989141464233398
ref4 square mean:  tensor(69.0862, device='cuda:0')
end_conv Time:  10.043654203414917
end_conv square mean:  tensor(0.2107, device='cuda:0')
end Time:  10.045524597167969
output square mean:  tensor(176.2602, device='cuda:0')
	 conv1 Time:  0.0008533000946044922
	 output square mean:  tensor(0.3863, device='cuda:0')
	 conv2 Time:  0.003858804702758789
	 output square mean:  tensor(0.3478, device='cuda:0')
	 conv1 Time:  0.0005068778991699219
	 output square mean:  tensor(0.4187, device='cuda:0')
	 conv2 Time:  0.0036382675170898438
	 output square mean:  tensor(0.3846, device='cuda:0')
layer1 Time:  0.013040304183959961
layer1 square mean:  tensor(86.4219, device='cuda:0')
	 conv1 Time:  0.000495433807373

KeyboardInterrupt: 

In [15]:
# Re-run the forward-pass smoke test on a fresh dummy batch:
# 256 standard-normal 1x28x28 inputs and an all-zero long index per sample.
x = torch.randn(256,1,28,28).cuda()
t = torch.zeros(256,dtype=torch.long).cuda()
# Last expression of the cell: displays the output shape.
model(x,t).shape

	 conv1 Time:  0.0007953643798828125
	 conv2 Time:  0.0018312931060791016
	 conv1 Time:  0.0007104873657226562
	 conv2 Time:  0.0014810562133789062
layer1 Time:  0.003811359405517578
	 conv1 Time:  0.00069427490234375
	 conv2 Time:  0.0017116069793701172
	 conv1 Time:  0.0006601810455322266
	 conv2 Time:  0.0013823509216308594
	 conv1 Time:  0.0007190704345703125
	 conv2 Time:  0.0015227794647216797
	 conv1 Time:  0.0007143020629882812
	 conv2 Time:  0.0014760494232177734
layer3 Time:  0.010928869247436523
	 conv1 Time:  0.0007159709930419922
	 conv2 Time:  0.0014874935150146484
	 conv1 Time:  0.0007293224334716797
	 conv2 Time:  0.001512289047241211
layer4 Time:  0.014353752136230469
ref2 Time:  0.018111467361450195
ref4 Time:  0.02396082878112793
end_conv Time:  0.024707317352294922
end Time:  0.024844884872436523


torch.Size([256, 1, 28, 28])

In [13]:
# Print the min/max of the weights (and of the bias, when the layer has one)
# for every Conv2d submodule of the model, one line per layer.
for m in model.modules():
    if not isinstance(m, nn.Conv2d):
        continue
    fields = ['m weight range:', m.weight.min().item(), m.weight.max().item()]
    if m.bias is not None:
        # Extend the flat field list instead of passing a tuple to print(),
        # which rendered the bias info as a tuple repr "('m bias range:', a, b)".
        fields += ['m bias range:', m.bias.min().item(), m.bias.max().item()]
    print(*fields)

m weight range: -0.33154308795928955 0.33294522762298584 ('m bias range:', -0.3211643695831299, 0.3225468397140503)
m weight range: -0.04171708971261978 0.041503746062517166 ('m bias range:', 0.019171390682458878, 0.019171390682458878)
m weight range: -0.04193001613020897 0.041891563683748245 ('m bias range:', -0.041510991752147675, 0.040913477540016174)
m weight range: -0.041951701045036316 0.04194023087620735 ('m bias range:', -0.03898942098021507, 0.041212961077690125)
m weight range: -0.04195263236761093 0.041894663125276566 ('m bias range:', -0.041217345744371414, 0.041234761476516724)
m weight range: -0.0419234074652195 0.04190922901034355 ('m bias range:', -0.04168468341231346, 0.04003095254302025)
m weight range: -0.0419343039393425 0.04193822294473648 ('m bias range:', -0.0403323657810688, 0.0391434021294117)
m weight range: -0.041952118277549744 0.04192785546183586 ('m bias range:', -0.041015464812517166, 0.04176655411720276)
m weight range: -0.12503769993782043 0.12520158290

                                                                                                                      

In [4]:
# # load a model
# import os
# model = UNet(L=n_sigma).cuda()
# print(sum(p.numel() for p in model.parameters() if p.requires_grad))
# eval_epoch = 70
# model.load_state_dict(torch.load(f'NCSN/models/{eval_epoch}.pth'))
# make_dataset(model, sigmas, eps=eps, T=T)
# print("Dataset created")
# os.system("python NCSN/evaluate.py")

In [23]:
# sampling
try:
    # First attempt: project code importable as the `NCSN` package.
    from NCSN.model import CondUNet, UNetv2
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
    from NCSN.ema import EMAHelper

except ImportError:
    # Fallback: same modules importable as top-level modules. Restricting the
    # handler to ImportError keeps KeyboardInterrupt and genuine errors raised
    # during import from being silently swallowed.
    from model import CondUNet, UNetv2
    from train import evaluate_denoising, cal_noise_level, langevin
    from ema import EMAHelper

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# val_loader = DataLoader(val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

# Sampling configuration.
criterion = nn.MSELoss()
init_sigma, final_sigma, n_sigma = 28, 0.01, 75
eps, T = 5e-5, 5
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

model = UNetv2(sigmas=torch.tensor(sigmas)).cuda()

# Checkpoint to evaluate: run timestamp and epoch index.
time_str = "20241023_115901"
eval_epoch = 479

if ema_decay is None:
    # No EMA: load the saved weights straight into the model.
    model.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}.pth"))
else:
    # EMA: restore the helper's saved state, then apply it to the model
    # (presumably copying the averaged weights in — confirm in the ema module).
    ema_helper = EMAHelper(ema_decay)
    ema_helper.register(model)
    ema_helper.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}_ema.pth"))
    ema_helper.ema(model)
model.cuda()

# Starting point: ten standard-normal 1x28x28 tensors.
# (Alternatives tried earlier: uniform noise via torch.rand, and rescaling with (x+1)/2.)
x = torch.randn(10, 1, 28, 28).cuda()
y = langevin(
    model,
    x,
    sigmas,
    eps=eps,
    T=T,
    save=True,
    epochs=eval_epoch,
    clamp=False,
    time_str=time_str,
    verbose=True,
)

level: 0, step_size: 391.99999999999886, grad_norm: 0.9822129607200623, image_norm: 575.2122192382812, snr: 0.6990952491760254, grad_mean_norm: 84.818603515625
level: 0, step_size: 391.99999999999886, grad_norm: 1.0097943544387817, image_norm: 670.8079223632812, snr: 0.7203842401504517, grad_mean_norm: 83.17820739746094
level: 0, step_size: 391.99999999999886, grad_norm: 1.036985993385315, image_norm: 728.0260620117188, snr: 0.7392300963401794, grad_mean_norm: 87.04425811767578
level: 0, step_size: 391.99999999999886, grad_norm: 1.0620559453964233, image_norm: 754.8051147460938, snr: 0.7627385854721069, grad_mean_norm: 81.56352996826172
level: 0, step_size: 391.99999999999886, grad_norm: 1.0763520002365112, image_norm: 769.5587768554688, snr: 0.7710176110267639, grad_mean_norm: 85.94133758544922
level: 1, step_size: 316.31488828837405, grad_norm: 1.2067806720733643, image_norm: 767.9408569335938, snr: 0.7572596073150635, grad_mean_norm: 92.18448638916016
level: 1, step_size: 316.314888

In [1]:
# visualize denoising
try:
    # First attempt: project code importable as the `NCSN` package.
    from NCSN.utils import val_set
    from NCSN.model import CondUNet, UNetv2
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
    from NCSN.ema import EMAHelper

except ImportError:
    # Fallback: same modules importable as top-level modules. Restricting the
    # handler to ImportError keeps KeyboardInterrupt and genuine errors raised
    # during import from being silently swallowed.
    from utils import val_set
    from model import CondUNet, UNetv2
    from train import evaluate_denoising, cal_noise_level, langevin
    from ema import EMAHelper

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Small shuffled validation batches (10 samples each) for the visualisation run.
val_loader = DataLoader(
    val_set,
    batch_size=10,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

# Denoising configuration.
criterion = nn.MSELoss()
init_sigma, final_sigma, n_sigma = 28, 0.01, 75
eps, T = 6e-6, 10
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

model = UNetv2(sigmas=torch.tensor(sigmas)).cuda()

# Checkpoint to evaluate: run timestamp and epoch index.
time_str = "20241022_182843"
eval_epoch = 199

if ema_decay is None:
    # No EMA: load the saved weights straight into the model.
    model.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}.pth"))
else:
    # EMA: restore the helper's saved state, then apply it to the model
    # (presumably copying the averaged weights in — confirm in the ema module).
    ema_helper = EMAHelper(ema_decay)
    ema_helper.register(model)
    ema_helper.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}_ema.pth"))
    ema_helper.ema(model)
model.cuda()

# Run the denoising evaluation with visualisation enabled; it returns the two
# MSE figures plus the original / corrupted / recovered batches.
corruption_mse, mse, original, broken, recovered = evaluate_denoising(
    model,
    sigmas,
    eps=eps,
    T=T,
    val_loader=val_loader,
    outdir=f'NCSN/rubbish.out',
    visualize=True,
)

Files already downloaded and verified
Files already downloaded and verified


Eval:   0%|                                                                                    | 0/2000 [00:00<?, ?it/s]

level: 0, step_size: 47.03999999999986, grad_norm: 0.9159242510795593, image_norm: 136.65994262695312, snr: 0.22355157136917114, grad_mean_norm: 243.8917236328125
level: 0, step_size: 47.03999999999986, grad_norm: 1.0196850299835205, image_norm: 177.761962890625, snr: 0.24864326417446136, grad_mean_norm: 85.75312805175781
level: 0, step_size: 47.03999999999986, grad_norm: 1.0420478582382202, image_norm: 203.01678466796875, snr: 0.2536088824272156, grad_mean_norm: 90.26822662353516
level: 0, step_size: 47.03999999999986, grad_norm: 1.040546178817749, image_norm: 223.3791046142578, snr: 0.255245178937912, grad_mean_norm: 88.83049011230469
level: 0, step_size: 47.03999999999986, grad_norm: 1.0552915334701538, image_norm: 241.1633758544922, snr: 0.2593735456466675, grad_mean_norm: 90.06800842285156
level: 0, step_size: 47.03999999999986, grad_norm: 1.046255350112915, image_norm: 255.1326141357422, snr: 0.2589056193828583, grad_mean_norm: 82.48102569580078
level: 0, step_size: 47.0399999999

Corruption MSE: 0.061749, Recovered MSE: 0.010772:   0%|▏                             | 10/2000 [00:02<07:04,  4.69it/s]

level: 73, step_size: 7.435628505275253e-06, grad_norm: 1756.0316162109375, image_norm: 8.873946189880371, snr: 0.17186549305915833, grad_mean_norm: 40.148109436035156
level: 74, step_size: 6e-06, grad_norm: 1948.3406982421875, image_norm: 8.873139381408691, snr: 0.16766342520713806, grad_mean_norm: 39.099159240722656
level: 74, step_size: 6e-06, grad_norm: 1938.2613525390625, image_norm: 8.872430801391602, snr: 0.16760388016700745, grad_mean_norm: 37.262847900390625
level: 74, step_size: 6e-06, grad_norm: 1922.4471435546875, image_norm: 8.873458862304688, snr: 0.16932858526706696, grad_mean_norm: 36.89952850341797
level: 74, step_size: 6e-06, grad_norm: 1923.2427978515625, image_norm: 8.873641014099121, snr: 0.16934734582901, grad_mean_norm: 38.14802932739258
level: 74, step_size: 6e-06, grad_norm: 1911.0859375, image_norm: 8.873896598815918, snr: 0.1662289947271347, grad_mean_norm: 38.34357833862305
level: 74, step_size: 6e-06, grad_norm: 1904.0819091796875, image_norm: 8.87335109710




In [2]:
# experiment on recovering
try:
    # First attempt: project code importable as the `NCSN` package.
    from NCSN.utils import val_set
    from NCSN.model import CondUNet, UNetv2
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
    from NCSN.ema import EMAHelper

except ImportError:
    # Fallback: same modules importable as top-level modules. Restricting the
    # handler to ImportError keeps KeyboardInterrupt and genuine errors raised
    # during import from being silently swallowed.
    from utils import val_set
    from model import CondUNet, UNetv2
    from train import evaluate_denoising, cal_noise_level, langevin
    from ema import EMAHelper

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Shuffled validation batches of 500 samples for the recovery experiment.
val_loader = DataLoader(
    val_set,
    batch_size=500,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

# Denoising configuration.
criterion = nn.MSELoss()
init_sigma, final_sigma, n_sigma = 28, 0.01, 75
eps, T = 6e-6, 10
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

model = UNetv2(sigmas=torch.tensor(sigmas)).cuda()

# Checkpoint to evaluate: run timestamp and epoch index.
# NOTE(review): the saved output of this cell is a FileNotFoundError for
# .../20241009_193254/models/010_ema.pth — verify that this run/epoch exists.
time_str = "20241009_193254"
eval_epoch = 10

if ema_decay is None:
    # No EMA: load the saved weights straight into the model.
    model.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}.pth"))
else:
    # EMA: restore the helper's saved state, then apply it to the model
    # (presumably copying the averaged weights in — confirm in the ema module).
    ema_helper = EMAHelper(ema_decay)
    ema_helper.register(model)
    ema_helper.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}_ema.pth"))
    ema_helper.ema(model)
model.cuda()

# Run the denoising evaluation; it returns the two MSE figures plus the
# original / corrupted / recovered batches.
corruption_mse, mse, original, broken, recovered = evaluate_denoising(
    model,
    sigmas,
    eps=eps,
    T=T,
    val_loader=val_loader,
    outdir=f'NCSN/rubbish.out',
)

FileNotFoundError: [Errno 2] No such file or directory: '/nobackup/users/sqa24/NCSN/20241009_193254/models/010_ema.pth'