# Implementation of NCSN (Noise Conditional Score Networks)

In [1]:
# Re-import edited project modules (e.g. NCSN.train, NCSN.model) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Record the GPU and the Python interpreter this notebook was executed with.
!nvidia-smi
!which python

Sun Sep 29 15:26:41 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   41C    P0              54W / 184W |  14898MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Support both layouts: running from the repo root (package `NCSN`) or from inside NCSN/.
try:
    from NCSN.train import train, make_dataset, cal_noise_level, langevin
    from NCSN.utils import train_set, val_set
    from NCSN.model import UNet
except ImportError:
    # Fall back only when the package path is genuinely missing; a bare `except:` would
    # also hide real errors raised while importing NCSN (and even KeyboardInterrupt).
    from train import train, make_dataset, cal_noise_level, langevin
    from utils import train_set, val_set
    from model import UNet

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import datetime
import os

# Training batches of 256 and validation batches of 500; both shuffled, last partial batch kept.
train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)
val_loader = DataLoader(
    val_set,
    batch_size=500,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

def timestr():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``.

    The result is passed to ``train`` as ``time_str``; later cells use strings of
    this form to locate a run's checkpoint directory.
    """
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Here are the hyperparameters

# Seed torch's global RNG (CPU and all CUDA devices) so weight init, batch shuffling
# (the loaders draw their shuffle seed from this RNG when iterated) and any sampling
# noise drawn during training repeat across re-runs.
# NOTE(review): cuDNN kernels may still be non-deterministic; set
# torch.backends.cudnn.deterministic if bit-exact reruns are required.
seed = 0
torch.manual_seed(seed)

epochs = 100
criterion = nn.MSELoss()
init_sigma = 28     # first noise level passed to cal_noise_level (presumably the largest)
final_sigma = 0.01  # last noise level passed to cal_noise_level (presumably the smallest)
n_sigma = 75        # number of noise levels; also used as the UNet's L
eps = 6e-6          # Langevin step-size parameter forwarded to train()
T = 5               # forwarded to train(); presumably Langevin steps per noise level — TODO confirm in train.py
eval_freq = 1       # forwarded to train(); presumably evaluate every `eval_freq` epochs

model = UNet(L=n_sigma)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print("model:", model)
# Total parameter count (trainable and frozen alike).
num_params = sum(p.numel() for p in model.parameters())
print("number of parameters:", num_params)

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

# Timestamp identifying this run; train() receives it as `time_str`.
time_str = timestr()
print("time string:", time_str)

train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    sigmas=sigmas,
    eps=eps,
    T=T,
    time_str=time_str,
    eval_freq=eval_freq,
)

                                                                                                                        

Epoch 0: train loss 1.1735482025653758


                                                                                                                        

Epoch 1: train loss 0.7068929991823562


                                                                                                                        

Epoch 2: train loss 0.6075381304355378


                                                                                                                        

Epoch 3: train loss 0.556124329059682


                                                                                                                        

Epoch 4: train loss 0.5176817952318394


                                                                                                                        

Epoch 5: train loss 0.4852951870319691


                                                                                                                        

Epoch 6: train loss 0.44922798009628945


                                                                                                                        

Epoch 7: train loss 0.42224347971855325


                                                                                                                        

Epoch 8: train loss 0.3982025449580335


                                                                                                                        

Epoch 9: train loss 0.37406652227361153


                                                                                                                        

Epoch 10: train loss 0.3518085669963918


                                                                                                                        

Epoch 11: train loss 0.33376552018713446


                                                                                                                        

Epoch 12: train loss 0.3173093057693319


                                                                                                                        

Epoch 13: train loss 0.30393104565904494


                                                                                                                        

Epoch 14: train loss 0.28836932537403515


                                                                                                                        

Epoch 15: train loss 0.2784779130144322


                                                                                                                        

Epoch 16: train loss 0.26496840683703726


                                                                                                                        

Epoch 17: train loss 0.2539548805419435


                                                                                                                        

Epoch 18: train loss 0.2448856199041326


                                                                                                                        

Epoch 19: train loss 0.23588139332355337


                                                                                                                        

Epoch 20: train loss 0.22658174925662103


                                                                                                                        

Epoch 21: train loss 0.2171549380459684


                                                                                                                        

Epoch 22: train loss 0.20926489582721222


                                                                                                                        

Epoch 23: train loss 0.1986772933538924


                                                                                                                        

Epoch 24: train loss 0.19319733482726076


                                                                                                                        

Epoch 25: train loss 0.18753301317387439


                                                                                                                        

Epoch 26: train loss 0.17868791869346132


                                                                                                                        

Epoch 27: train loss 0.17255715651715056


                                                                                                                        

Epoch 28: train loss 0.1677958939303743


                                                                                                                        

Epoch 29: train loss 0.162903623314614


                                                                                                                        

Epoch 30: train loss 0.15703342417453198


                                                                                                                        

Epoch 31: train loss 0.15350709208782684


                                                                                                                        

Epoch 32: train loss 0.14876404651936065


                                                                                                                        

Epoch 33: train loss 0.1446277068016377


                                                                                                                        

Epoch 34: train loss 0.1414785366743169


                                                                                                                        

Epoch 35: train loss 0.1378587454557419


                                                                                                                        

Epoch 36: train loss 0.13437657977672332


                                                                                                                        

Epoch 37: train loss 0.13087289720139605


                                                                                                                        

Epoch 38: train loss 0.1293339936657155


                                                                                                                        

Epoch 39: train loss 0.1264410804243798


                                                                                                                        

Epoch 40: train loss 0.12352061905759446


Epoch 42/100:  94%|██████████████████████████████████████████████▎  | 222/235 [00:24<00:01,  8.92it/s, Train Loss=0.121]

KeyboardInterrupt: 

In [4]:
# Disabled cell: loads a checkpoint, builds a sample dataset with make_dataset, then
# runs NCSN/evaluate.py on it. Uncomment to use.
# NOTE(review): the checkpoint path here (NCSN/models/{eval_epoch}.pth) does not match
# the {time_str}/models/{eval_epoch:03d}.pth layout loaded by the cells below — confirm
# which layout train() actually writes before re-enabling.
# # load a model
# import os
# model = UNet(L=n_sigma).cuda()
# print(sum(p.numel() for p in model.parameters() if p.requires_grad))
# eval_epoch = 70
# model.load_state_dict(torch.load(f'NCSN/models/{eval_epoch}.pth'))
# make_dataset(model, sigmas, eps=eps, T=T)
# print("Dataset created")
# os.system("python NCSN/evaluate.py")

In [6]:
# sampling: run Langevin dynamics from noise using a saved checkpoint.
# NOTE(review): these noise settings (init_sigma=1, n_sigma=10) differ from the training
# cell above. They must match the run that produced `time_str`'s checkpoint, because
# UNet(L=n_sigma) has to agree with the saved state dict.
import os

try:
    from NCSN.utils import val_set
    from NCSN.model import UNet
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
except ImportError:
    # Flat layout: running from inside the NCSN directory.
    from utils import val_set
    from model import UNet
    from train import evaluate_denoising, cal_noise_level, langevin

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

criterion = nn.MSELoss()
init_sigma = 1
final_sigma = 0.01
n_sigma = 10
eps = 2e-5
T = 100

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

# Root holding per-run checkpoint directories; override via the NCSN_CKPT_ROOT env var.
ckpt_root = os.environ.get("NCSN_CKPT_ROOT", "/nobackup/users/sqa24/NCSN")
time_str = "20241009_173109"
eval_epoch = 5

model = UNet(L=n_sigma).cuda()
# The file is only fed to load_state_dict, so weights-only unpickling suffices and
# avoids executing arbitrary pickled code.
model.load_state_dict(
    torch.load(f"{ckpt_root}/{time_str}/models/{eval_epoch:03d}.pth", weights_only=True)
)
# Inference mode; matters if UNet contains BatchNorm/Dropout layers.
model.eval()

# Start from Gaussian noise rescaled to mean 0.5, std 0.5: 10 samples of shape 1x28x28.
x = torch.randn(10, 1, 28, 28).cuda()
x = (x + 1) / 2
y = langevin(model, x, sigmas, eps=eps, T=T, save=True, epochs=eval_epoch, clamp=False, time_str=time_str)

In [None]:
# experiment on recovering: evaluate_denoising returns
# (corruption_mse, mse, original, broken, recovered). Judging by the names it corrupts
# validation images and recovers them with the score model — confirm in train.py.
# NOTE(review): noise settings must match the run that produced `time_str`'s checkpoint.
import os

try:
    from NCSN.utils import val_set
    from NCSN.model import UNet
    from NCSN.train import evaluate_denoising, cal_noise_level
except ImportError:
    # Flat layout: running from inside the NCSN directory.
    from utils import val_set
    from model import UNet
    from train import evaluate_denoising, cal_noise_level

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

val_loader = DataLoader(val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

criterion = nn.MSELoss()
init_sigma = 1
final_sigma = 0.01
n_sigma = 10
eps = 8e-5
T = 20

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

# Root holding per-run checkpoint directories; override via the NCSN_CKPT_ROOT env var.
ckpt_root = os.environ.get("NCSN_CKPT_ROOT", "/nobackup/users/sqa24/NCSN")
time_str = "20241009_173109"
eval_epoch = 5

model = UNet(L=n_sigma).cuda()
# The file is only fed to load_state_dict, so weights-only unpickling suffices.
model.load_state_dict(
    torch.load(f"{ckpt_root}/{time_str}/models/{eval_epoch:03d}.pth", weights_only=True)
)
# Inference mode; matters if UNet contains BatchNorm/Dropout layers.
model.eval()

corruption_mse, mse, original, broken, recovered = evaluate_denoising(
    model, sigmas, eps=eps, T=T, val_loader=val_loader, outdir='NCSN/rubbish.out'
)
# Display the two error metrics as the cell's output.
corruption_mse, mse