# Implementation of NCSN (Noise Conditional Score Networks)

In [1]:
%load_ext autoreload
%autoreload 2
!nvidia-smi
!which python

Tue Oct 22 19:49:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   44C    P0              43W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  | 00000004:05:00.0 Off |  

In [2]:
try:
    from NCSN.train import train, make_dataset, cal_noise_level, langevin
    from NCSN.utils import train_set, val_set
    from NCSN.model import CondUNet, UNetv2
except:
    from train import train, make_dataset, cal_noise_level, langevin
    from utils import train_set, val_set
    from model import CondUNet, UNetv2

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import datetime
import os

train_loader = DataLoader(train_set, 256, shuffle=True, drop_last=False, pin_memory=True)
val_loader = DataLoader(val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

def timestr():
    now = datetime.datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Here are the hyperparameters

epochs = 200
criterion = nn.MSELoss()
init_sigma = 28
final_sigma = 0.01
n_sigma = 75
eps = 1e-5
T = 5
eval_freq = 5
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)
model = UNetv2(sigmas=torch.tensor(sigmas))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print("model:", model)
# calculate the number of parameters
num_params = sum(p.numel() for p in model.parameters())
print("number of parameters:", num_params)


time_str = timestr()
print("time string:", time_str)

train(epochs=epochs, model=model, optimizer=optimizer, criterion=criterion, train_loader=train_loader, val_loader=val_loader, sigmas=sigmas, eps=eps, T=T, time_str=time_str, eval_freq=eval_freq, ema_decay=ema_decay)

model: UNetv2(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_up2): Conv2d(32, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_up3): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (up1): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))
  (up2): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
number of parameters: 264097
time string: 20241022_182843
code files copied


Epoch 2/200:   1%|▍                                                     | 2/235 [00:00<00:20, 11.39it/s, Train Loss=1.5]

Epoch 0: train loss 0.859027462817253


Epoch 3/200:   1%|▍                                                    | 2/235 [00:00<00:20, 11.39it/s, Train Loss=1.24]

Epoch 1: train loss 0.6839413024009542


Epoch 4/200:   1%|▍                                                    | 2/235 [00:00<00:20, 11.34it/s, Train Loss=1.17]

Epoch 2: train loss 0.5908247067573222


Epoch 5/200:   1%|▍                                                    | 2/235 [00:00<00:20, 11.37it/s, Train Loss=1.01]

Epoch 3: train loss 0.5324361089696276


                                                                                                                        

Epoch 4: train loss 0.4908827690368003


Corruption MSE: 0.063251, Recovered MSE: 22.471667: 100%|██████████████████████████| 2000/2000 [00:07<00:00, 269.83it/s]
Epoch 6/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.962]

Current best MSE: inf -> 22.4716669921875


Epoch 7/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.859]

Epoch 5: train loss 0.45820495346759227


Epoch 8/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.843]

Epoch 6: train loss 0.4276010959706408


Epoch 9/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.32it/s, Train Loss=0.765]

Epoch 7: train loss 0.4025224393986641


Epoch 10/200:   1%|▍                                                  | 2/235 [00:00<00:21, 11.04it/s, Train Loss=0.693]

Epoch 8: train loss 0.3826751970230265


                                                                                                                        

Epoch 9: train loss 0.3634591368918723


Corruption MSE: 0.063603, Recovered MSE: 0.865407: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 269.57it/s]
Epoch 11/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.61it/s, Train Loss=0.718]

Current best MSE: 22.4716669921875 -> 0.8654071044921875


Epoch 12/200:   1%|▍                                                  | 2/235 [00:00<00:21, 11.07it/s, Train Loss=0.658]

Epoch 10: train loss 0.345896104295203


Epoch 13/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.692]

Epoch 11: train loss 0.3283837616443634


Epoch 14/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.586]

Epoch 12: train loss 0.3172486161932032


Epoch 15/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.24it/s, Train Loss=0.637]

Epoch 13: train loss 0.306239194565631


                                                                                                                        

Epoch 14: train loss 0.2967737380494463


Corruption MSE: 0.063277, Recovered MSE: 0.070385: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.15it/s]
Epoch 16/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.61it/s, Train Loss=0.599]

Current best MSE: 0.8654071044921875 -> 0.0703850040435791


Epoch 17/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.577]

Epoch 15: train loss 0.2929829736339285


Epoch 18/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.542]

Epoch 16: train loss 0.28272727398162195


Epoch 19/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.555]

Epoch 17: train loss 0.2771655801128834


Epoch 20/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.26it/s, Train Loss=0.549]

Epoch 18: train loss 0.27051367835795626


                                                                                                                        

Epoch 19: train loss 0.26124767122116493


Corruption MSE: 0.063627, Recovered MSE: 0.034632: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.32it/s]
Epoch 21/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.60it/s, Train Loss=0.516]

Current best MSE: 0.0703850040435791 -> 0.03463201999664307


Epoch 22/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.15it/s, Train Loss=0.518]

Epoch 20: train loss 0.2571842821354562


Epoch 23/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.483]

Epoch 21: train loss 0.25072640351792597


Epoch 24/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.499]

Epoch 22: train loss 0.24482885966909693


Epoch 25/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.497]

Epoch 23: train loss 0.23908180550057836


                                                                                                                        

Epoch 24: train loss 0.23416482736455632


Corruption MSE: 0.063745, Recovered MSE: 0.034992: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.37it/s]
Epoch 27/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.451]

Epoch 25: train loss 0.22934356598143882


Epoch 28/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.27it/s, Train Loss=0.441]

Epoch 26: train loss 0.22452552014208854


Epoch 29/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.432]

Epoch 27: train loss 0.22066808495115728


Epoch 30/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.415]

Epoch 28: train loss 0.21666624723596775


                                                                                                                        

Epoch 29: train loss 0.21123694067305707


Corruption MSE: 0.062824, Recovered MSE: 0.034241: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.45it/s]
Epoch 31/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.58it/s, Train Loss=0.393]

Current best MSE: 0.03463201999664307 -> 0.03424135971069336


Epoch 32/200:   1%|▍                                                  | 2/235 [00:00<00:21, 11.09it/s, Train Loss=0.409]

Epoch 30: train loss 0.20714781239945837


Epoch 33/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.404]

Epoch 31: train loss 0.2058558814069058


Epoch 34/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.41]

Epoch 32: train loss 0.20111428930404338


Epoch 35/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.407]

Epoch 33: train loss 0.19937373031961156


                                                                                                                        

Epoch 34: train loss 0.19573026711636402


Corruption MSE: 0.063299, Recovered MSE: 0.031557: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.33it/s]
Epoch 36/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.59it/s, Train Loss=0.404]

Current best MSE: 0.03424135971069336 -> 0.03155719995498657


Epoch 37/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.378]

Epoch 35: train loss 0.1922135018287821


Epoch 38/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.367]

Epoch 36: train loss 0.18997571690285459


Epoch 39/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.373]

Epoch 37: train loss 0.18739591811565642


Epoch 40/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.364]

Epoch 38: train loss 0.1847243163179844


                                                                                                                        

Epoch 39: train loss 0.1817318353881227


Corruption MSE: 0.063168, Recovered MSE: 0.027131: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.30it/s]
Epoch 41/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.58it/s, Train Loss=0.351]

Current best MSE: 0.03155719995498657 -> 0.027131324291229247


Epoch 42/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.36]

Epoch 40: train loss 0.18009236049144825


Epoch 43/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.24it/s, Train Loss=0.382]

Epoch 41: train loss 0.1771092082591767


Epoch 44/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.328]

Epoch 42: train loss 0.17494234816825135


Epoch 45/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.40it/s, Train Loss=0.336]

Epoch 43: train loss 0.17246685345122154


                                                                                                                        

Epoch 44: train loss 0.17206198604817086


Corruption MSE: 0.063075, Recovered MSE: 0.023604: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.12it/s]
Epoch 46/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.53it/s, Train Loss=0.34]

Current best MSE: 0.027131324291229247 -> 0.02360369825363159


Epoch 47/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.35]

Epoch 45: train loss 0.17090274697922644


Epoch 48/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.19it/s, Train Loss=0.336]

Epoch 46: train loss 0.16750357563191273


Epoch 49/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.27it/s, Train Loss=0.328]

Epoch 47: train loss 0.16673955777858165


Epoch 50/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.33it/s, Train Loss=0.336]

Epoch 48: train loss 0.16450463476333213


                                                                                                                        

Epoch 49: train loss 0.16458004110671104


Corruption MSE: 0.063258, Recovered MSE: 0.021561: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.02it/s]
Epoch 51/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.60it/s, Train Loss=0.326]

Current best MSE: 0.02360369825363159 -> 0.021560906410217285


Epoch 52/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.317]

Epoch 50: train loss 0.16166934611949516


Epoch 53/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.32]

Epoch 51: train loss 0.16132931582471158


Epoch 54/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.313]

Epoch 52: train loss 0.15922782706453445


Epoch 55/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.319]

Epoch 53: train loss 0.1580064504070485


                                                                                                                        

Epoch 54: train loss 0.15764250070490735


Corruption MSE: 0.062952, Recovered MSE: 0.020424: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.09it/s]
Epoch 56/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.31]

Current best MSE: 0.021560906410217285 -> 0.02042408227920532


Epoch 57/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.293]

Epoch 55: train loss 0.15566779895031707


Epoch 58/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.327]

Epoch 56: train loss 0.1537410593413292


Epoch 59/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.329]

Epoch 57: train loss 0.15403796599266376


Epoch 60/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.304]

Epoch 58: train loss 0.15306366396711227


                                                                                                                        

Epoch 59: train loss 0.15242252590808464


Corruption MSE: 0.063691, Recovered MSE: 0.020379: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.48it/s]
Epoch 61/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.62it/s, Train Loss=0.288]

Current best MSE: 0.02042408227920532 -> 0.02037925577163696


Epoch 62/200:   1%|▍                                                  | 2/235 [00:00<00:21, 11.09it/s, Train Loss=0.295]

Epoch 60: train loss 0.15145528303815964


Epoch 63/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.316]

Epoch 61: train loss 0.14881277249214497


Epoch 64/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.282]

Epoch 62: train loss 0.1481739417352575


Epoch 65/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.31it/s, Train Loss=0.293]

Epoch 63: train loss 0.1466607097298541


                                                                                                                        

Epoch 64: train loss 0.14602479465464327


Corruption MSE: 0.063321, Recovered MSE: 0.019797: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.06it/s]
Epoch 66/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.277]

Current best MSE: 0.02037925577163696 -> 0.01979719352722168


Epoch 67/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.292]

Epoch 65: train loss 0.14669729325365513


Epoch 68/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.285]

Epoch 66: train loss 0.1448725940699273


Epoch 69/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.281]

Epoch 67: train loss 0.14522760225103257


Epoch 70/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.272]

Epoch 68: train loss 0.14377667618558762


                                                                                                                        

Epoch 69: train loss 0.1427923561093655


Corruption MSE: 0.063547, Recovered MSE: 0.020440: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 269.73it/s]
Epoch 72/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.29]

Epoch 70: train loss 0.14237501126654606


Epoch 73/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.285]

Epoch 71: train loss 0.14177614834080352


Epoch 74/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.278]

Epoch 72: train loss 0.13996847363862586


Epoch 75/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.288]

Epoch 73: train loss 0.14060537558286748


                                                                                                                        

Epoch 74: train loss 0.13983499303143077


Corruption MSE: 0.063768, Recovered MSE: 0.020772: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.37it/s]
Epoch 77/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.258]

Epoch 75: train loss 0.13999576400569144


Epoch 78/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.276]

Epoch 76: train loss 0.13843538957707424


Epoch 79/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.279]

Epoch 77: train loss 0.1388238325398019


Epoch 80/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.279]

Epoch 78: train loss 0.13736441167111094


                                                                                                                        

Epoch 79: train loss 0.1371757280636341


Corruption MSE: 0.063230, Recovered MSE: 0.020564: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.73it/s]
Epoch 82/200:   1%|▍                                                  | 2/235 [00:00<00:21, 11.06it/s, Train Loss=0.258]

Epoch 80: train loss 0.13650093921955594


Epoch 83/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.287]

Epoch 81: train loss 0.13538390141218268


Epoch 84/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.271]

Epoch 82: train loss 0.13597636631828675


Epoch 85/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.328]

Epoch 83: train loss 0.13545020869437685


                                                                                                                        

Epoch 84: train loss 0.13528253905316617


Corruption MSE: 0.063569, Recovered MSE: 0.020712: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 269.44it/s]
Epoch 87/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.34it/s, Train Loss=0.268]

Epoch 85: train loss 0.134137062569882


Epoch 88/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.283]

Epoch 86: train loss 0.13350899305115355


Epoch 89/200:   1%|▍                                                    | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.3]

Epoch 87: train loss 0.13362099344426012


Epoch 90/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.246]

Epoch 88: train loss 0.13294790219119254


                                                                                                                        

Epoch 89: train loss 0.1320389649018328


Corruption MSE: 0.063121, Recovered MSE: 0.020649: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.31it/s]
Epoch 92/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.29]

Epoch 90: train loss 0.13118743550904255


Epoch 93/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.26it/s, Train Loss=0.265]

Epoch 91: train loss 0.13155740553394277


Epoch 94/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.277]

Epoch 92: train loss 0.13096307053210887


Epoch 95/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.258]

Epoch 93: train loss 0.13098567721057444


                                                                                                                        

Epoch 94: train loss 0.13068048126519993


Corruption MSE: 0.063386, Recovered MSE: 0.020944: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.31it/s]
Epoch 97/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.267]

Epoch 95: train loss 0.13046671028466936


Epoch 98/200:   1%|▍                                                   | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.25]

Epoch 96: train loss 0.1297378901154437


Epoch 99/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.264]

Epoch 97: train loss 0.1301741768388038


Epoch 100/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.257]

Epoch 98: train loss 0.12854104073757822


                                                                                                                        

Epoch 99: train loss 0.12855251152464683


Corruption MSE: 0.063495, Recovered MSE: 0.020649: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.27it/s]
Epoch 102/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.26]

Epoch 100: train loss 0.12779595953352907


Epoch 103/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.269]

Epoch 101: train loss 0.12819080184748832


Epoch 104/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.263]

Epoch 102: train loss 0.12765642001907876


Epoch 105/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.259]

Epoch 103: train loss 0.1283897728361982


                                                                                                                        

Epoch 104: train loss 0.12686093570070064


Corruption MSE: 0.063317, Recovered MSE: 0.020396: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.24it/s]
Epoch 107/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.263]

Epoch 105: train loss 0.12836695313453675


Epoch 108/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.25]

Epoch 106: train loss 0.12673480770689377


Epoch 109/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.34it/s, Train Loss=0.252]

Epoch 107: train loss 0.12573092915910355


Epoch 110/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.263]

Epoch 108: train loss 0.1264291505864326


                                                                                                                        

Epoch 109: train loss 0.12527477348104438


Corruption MSE: 0.063683, Recovered MSE: 0.020465: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.12it/s]
Epoch 112/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.248]

Epoch 110: train loss 0.12582724779210192


Epoch 113/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.268]

Epoch 111: train loss 0.12473677644070158


Epoch 114/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.257]

Epoch 112: train loss 0.1243591070175171


Epoch 115/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.25]

Epoch 113: train loss 0.12449872829812639


                                                                                                                        

Epoch 114: train loss 0.12458135117875768


Corruption MSE: 0.063465, Recovered MSE: 0.019710: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.59it/s]
Epoch 116/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.245]

Current best MSE: 0.01979719352722168 -> 0.01970965814590454


Epoch 117/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.262]

Epoch 115: train loss 0.12382206127364585


Epoch 118/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.24]

Epoch 116: train loss 0.12336747630479487


Epoch 119/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.264]

Epoch 117: train loss 0.12332484893342283


Epoch 120/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.242]

Epoch 118: train loss 0.12329484772809009


                                                                                                                        

Epoch 119: train loss 0.12628978703250276


Corruption MSE: 0.062862, Recovered MSE: 0.019385: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.47it/s]
Epoch 121/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.62it/s, Train Loss=0.251]

Current best MSE: 0.01970965814590454 -> 0.019385410785675048


Epoch 122/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.236]

Epoch 120: train loss 0.12325161724014486


Epoch 123/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.232]

Epoch 121: train loss 0.12276738656962172


Epoch 124/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.237]

Epoch 122: train loss 0.12237586420267187


Epoch 125/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.245]

Epoch 123: train loss 0.12255880743899245


                                                                                                                        

Epoch 124: train loss 0.12240887347688066


Corruption MSE: 0.063554, Recovered MSE: 0.019197: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.36it/s]
Epoch 126/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.234]

Current best MSE: 0.019385410785675048 -> 0.019197020530700683


Epoch 127/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.254]

Epoch 125: train loss 0.12166365023623121


Epoch 128/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.253]

Epoch 126: train loss 0.12190272392744714


Epoch 129/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.241]

Epoch 127: train loss 0.12159345254619071


Epoch 130/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.243]

Epoch 128: train loss 0.12250698024922228


                                                                                                                        

Epoch 129: train loss 0.12044163118017481


Corruption MSE: 0.063558, Recovered MSE: 0.018838: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.02it/s]
Epoch 131/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.253]

Current best MSE: 0.019197020530700683 -> 0.018837801933288573


Epoch 132/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.231]

Epoch 130: train loss 0.12105234772601026


Epoch 133/200:   1%|▍                                                 | 2/235 [00:00<00:21, 11.06it/s, Train Loss=0.242]

Epoch 131: train loss 0.12180107901705073


Epoch 134/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.242]

Epoch 132: train loss 0.12073918130803615


Epoch 135/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.232]

Epoch 133: train loss 0.12050531211685626


                                                                                                                        

Epoch 134: train loss 0.12030074025722261


Corruption MSE: 0.063960, Recovered MSE: 0.018549: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.27it/s]
Epoch 136/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.62it/s, Train Loss=0.237]

Current best MSE: 0.018837801933288573 -> 0.01854940128326416


Epoch 137/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.227]

Epoch 135: train loss 0.11948839074119609


Epoch 138/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.30it/s, Train Loss=0.233]

Epoch 136: train loss 0.11893501630488862


Epoch 139/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.231]

Epoch 137: train loss 0.11908819723002453


Epoch 140/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.226]

Epoch 138: train loss 0.11895119040570361


                                                                                                                        

Epoch 139: train loss 0.11891442286841412


Corruption MSE: 0.063700, Recovered MSE: 0.018198: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.29it/s]
Epoch 141/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.62it/s, Train Loss=0.245]

Current best MSE: 0.01854940128326416 -> 0.018197588443756102


Epoch 142/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.237]

Epoch 140: train loss 0.1188614805962177


Epoch 143/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.258]

Epoch 141: train loss 0.11822997537699152


Epoch 144/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.235]

Epoch 142: train loss 0.11856533399921783


Epoch 145/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.234]

Epoch 143: train loss 0.11862746477127076


                                                                                                                        

Epoch 144: train loss 0.11870580477283356


Corruption MSE: 0.063334, Recovered MSE: 0.017587: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.08it/s]
Epoch 146/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.62it/s, Train Loss=0.238]

Current best MSE: 0.018197588443756102 -> 0.017587113857269288


Epoch 147/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.234]

Epoch 145: train loss 0.1174482895338789


Epoch 148/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.243]

Epoch 146: train loss 0.11736462480844335


Epoch 149/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.233]

Epoch 147: train loss 0.11738282794013936


Epoch 150/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.252]

Epoch 148: train loss 0.11627369754491969


                                                                                                                        

Epoch 149: train loss 0.11648311925695297


Corruption MSE: 0.063506, Recovered MSE: 0.017462: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.04it/s]
Epoch 151/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.231]

Current best MSE: 0.017587113857269288 -> 0.017461986541748048


Epoch 152/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.236]

Epoch 150: train loss 0.11691089678317942


Epoch 153/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.231]

Epoch 151: train loss 0.11606646097720938


Epoch 154/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.234]

Epoch 152: train loss 0.11576821382375474


Epoch 155/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.226]

Epoch 153: train loss 0.115918168267037


                                                                                                                        

Epoch 154: train loss 0.11612759115214044


Corruption MSE: 0.063866, Recovered MSE: 0.017095: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.26it/s]
Epoch 156/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.62it/s, Train Loss=0.241]

Current best MSE: 0.017461986541748048 -> 0.017094592094421386


Epoch 157/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.225]

Epoch 155: train loss 0.11531510803293675


Epoch 158/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.228]

Epoch 156: train loss 0.11555520109039671


Epoch 159/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.232]

Epoch 157: train loss 0.11510013738211165


Epoch 160/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.238]

Epoch 158: train loss 0.11485823285072408


                                                                                                                        

Epoch 159: train loss 0.1158831166460159


Corruption MSE: 0.063526, Recovered MSE: 0.016527: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.60it/s]
Epoch 161/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.60it/s, Train Loss=0.218]

Current best MSE: 0.017094592094421386 -> 0.016527249813079833


Epoch 162/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.236]

Epoch 160: train loss 0.11459587213206798


Epoch 163/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.228]

Epoch 161: train loss 0.11431515714589585


Epoch 164/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.215]

Epoch 162: train loss 0.11393966652611469


Epoch 165/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.226]

Epoch 163: train loss 0.11424086838960648


                                                                                                                        

Epoch 164: train loss 0.1143592357001406


Corruption MSE: 0.063434, Recovered MSE: 0.016393: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.55it/s]
Epoch 166/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.223]

Current best MSE: 0.016527249813079833 -> 0.016392613887786865


Epoch 167/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.22]

Epoch 165: train loss 0.11438938087605416


Epoch 168/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.238]

Epoch 166: train loss 0.11352610271027748


Epoch 169/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.225]

Epoch 167: train loss 0.11340187718259527


Epoch 170/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.225]

Epoch 168: train loss 0.11328564692050852


                                                                                                                        

Epoch 169: train loss 0.11275709968932132


Corruption MSE: 0.063356, Recovered MSE: 0.016112: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.36it/s]
Epoch 171/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.61it/s, Train Loss=0.223]

Current best MSE: 0.016392613887786865 -> 0.01611180567741394


Epoch 172/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.219]

Epoch 170: train loss 0.1131298337845092


Epoch 173/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.223]

Epoch 171: train loss 0.11311738145478228


Epoch 174/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.221]

Epoch 172: train loss 0.11248682968794031


Epoch 175/200:   1%|▍                                                  | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.22]

Epoch 173: train loss 0.11249310256952935


                                                                                                                        

Epoch 174: train loss 0.11157135383245793


Corruption MSE: 0.062932, Recovered MSE: 0.015661: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.59it/s]
Epoch 176/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.224]

Current best MSE: 0.01611180567741394 -> 0.015660959005355836


Epoch 177/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.233]

Epoch 175: train loss 0.11206752074525711


Epoch 178/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.35it/s, Train Loss=0.229]

Epoch 176: train loss 0.11122222612512872


Epoch 179/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.215]

Epoch 177: train loss 0.11175525629139961


Epoch 180/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.231]

Epoch 178: train loss 0.11140319565509228


                                                                                                                        

Epoch 179: train loss 0.11095807793292593


Corruption MSE: 0.063600, Recovered MSE: 0.015585: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.80it/s]
Epoch 181/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.221]

Current best MSE: 0.015660959005355836 -> 0.015584585428237915


Epoch 182/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.213]

Epoch 180: train loss 0.11036612664131408


Epoch 183/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.204]

Epoch 181: train loss 0.11070944378350643


Epoch 184/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.40it/s, Train Loss=0.221]

Epoch 182: train loss 0.11009767686432981


Epoch 185/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.225]

Epoch 183: train loss 0.11082219373672567


                                                                                                                        

Epoch 184: train loss 0.11093870819883143


Corruption MSE: 0.063594, Recovered MSE: 0.015284: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 269.97it/s]
Epoch 186/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.63it/s, Train Loss=0.222]

Current best MSE: 0.015584585428237915 -> 0.015283512353897095


Epoch 187/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.216]

Epoch 185: train loss 0.11000582468002401


Epoch 188/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.227]

Epoch 186: train loss 0.10991515317495833


Epoch 189/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.218]

Epoch 187: train loss 0.10918758628850288


Epoch 190/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.36it/s, Train Loss=0.212]

Epoch 188: train loss 0.10876090063693676


                                                                                                                        

Epoch 189: train loss 0.10916099538828464


Corruption MSE: 0.063099, Recovered MSE: 0.014984: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.12it/s]
Epoch 191/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.60it/s, Train Loss=0.223]

Current best MSE: 0.015283512353897095 -> 0.01498395299911499


Epoch 192/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.216]

Epoch 190: train loss 0.1095327780601826


Epoch 193/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.37it/s, Train Loss=0.209]

Epoch 191: train loss 0.10943602158668193


Epoch 194/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.214]

Epoch 192: train loss 0.108037073307849


Epoch 195/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.39it/s, Train Loss=0.209]

Epoch 193: train loss 0.10922080792011099


                                                                                                                        

Epoch 194: train loss 0.10813406033718839


Corruption MSE: 0.062573, Recovered MSE: 0.014441: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.16it/s]
Epoch 196/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.51it/s, Train Loss=0.216]

Current best MSE: 0.01498395299911499 -> 0.014440977573394775


Epoch 197/200:   1%|▍                                                 | 2/235 [00:00<00:21, 11.07it/s, Train Loss=0.222]

Epoch 195: train loss 0.10932942552769438


Epoch 198/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.30it/s, Train Loss=0.211]

Epoch 196: train loss 0.10828428484023886


Epoch 199/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.214]

Epoch 197: train loss 0.10818495563370116


Epoch 200/200:   1%|▍                                                 | 2/235 [00:00<00:20, 11.38it/s, Train Loss=0.222]

Epoch 198: train loss 0.10855060993356908


                                                                                                                        

Epoch 199: train loss 0.10697510635599176


Corruption MSE: 0.063483, Recovered MSE: 0.014693: 100%|███████████████████████████| 2000/2000 [00:07<00:00, 270.44it/s]


In [4]:
# # load a model
# import os
# model = UNet(L=n_sigma).cuda()
# print(sum(p.numel() for p in model.parameters() if p.requires_grad))
# eval_epoch = 70
# model.load_state_dict(torch.load(f'NCSN/models/{eval_epoch}.pth'))
# make_dataset(model, sigmas, eps=eps, T=T)
# print("Dataset created")
# os.system("python NCSN/evaluate.py")

In [1]:
# sampling
try:
    from NCSN.model import CondUNet, UNetv2
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
    from NCSN.ema import EMAHelper

except:
    from model import CondUNet, UNetv2
    from train import evaluate_denoising, cal_noise_level, langevin
    from ema import EMAHelper

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# val_loader = DataLoader(val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

criterion = nn.MSELoss()
init_sigma = 28
final_sigma = 0.01
n_sigma = 75
eps = 6e-6
T = 10
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

model = UNetv2(sigmas=torch.tensor(sigmas)).cuda()
time_str = "20241022_182843"
eval_epoch = 199

if ema_decay is not None:
    ema_helper = EMAHelper(ema_decay)
    ema_helper.register(model)
    ema_helper.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}_ema.pth"))
    ema_helper.ema(model)
else:
    model.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}.pth"))
model.cuda()

# x = torch.rand(10, 1, 28, 28).cuda()
x = torch.randn(10, 1, 28, 28).cuda()
x = (x+1)/2
y = langevin(model, x, sigmas, eps=eps, T=T, save=True, epochs=eval_epoch, clamp=False, time_str=time_str, verbose=True)

Files already downloaded and verified
Files already downloaded and verified
level: 0, step_size: 47.03999999999986, grad_norm: 1.195145845413208, image_norm: 191.0624237060547, snr: 0.295379102230072, grad_mean_norm: 275.4219665527344
level: 0, step_size: 47.03999999999986, grad_norm: 0.9893203973770142, image_norm: 251.9375457763672, snr: 0.24224522709846497, grad_mean_norm: 82.0451889038086
level: 0, step_size: 47.03999999999986, grad_norm: 0.97486811876297, image_norm: 298.11053466796875, snr: 0.23741193115711212, grad_mean_norm: 85.88428497314453
level: 0, step_size: 47.03999999999986, grad_norm: 0.9631657004356384, image_norm: 333.8455810546875, snr: 0.23982571065425873, grad_mean_norm: 81.64767456054688
level: 0, step_size: 47.03999999999986, grad_norm: 0.9543590545654297, image_norm: 367.0375671386719, snr: 0.23320022225379944, grad_mean_norm: 80.03102111816406
level: 0, step_size: 47.03999999999986, grad_norm: 0.9616582989692688, image_norm: 393.6832580566406, snr: 0.2343251705

In [1]:
# visualize denoising
try:
    from NCSN.utils import val_set
    from NCSN.model import CondUNet, UNetv2
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
    from NCSN.ema import EMAHelper

except:
    from utils import val_set
    from model import CondUNet, UNetv2
    from train import evaluate_denoising, cal_noise_level, langevin
    from ema import EMAHelper

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

val_loader = DataLoader(val_set, 10, shuffle=True, drop_last=False, pin_memory=True)

criterion = nn.MSELoss()
init_sigma = 28
final_sigma = 0.01
n_sigma = 75
eps = 6e-6
T = 10
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

model = UNetv2(sigmas=torch.tensor(sigmas)).cuda()
time_str = "20241022_182843"
eval_epoch = 199

if ema_decay is not None:
    ema_helper = EMAHelper(ema_decay)
    ema_helper.register(model)
    ema_helper.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}_ema.pth"))
    ema_helper.ema(model)
else:
    model.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}.pth"))
model.cuda()

corruption_mse, mse, original, broken, recovered = evaluate_denoising(
    model, sigmas, eps=eps, T=T, val_loader=val_loader, outdir=f'NCSN/rubbish.out', visualize=True
)

Files already downloaded and verified
Files already downloaded and verified


Eval:   0%|                                                                                    | 0/2000 [00:00<?, ?it/s]

level: 0, step_size: 47.03999999999986, grad_norm: 0.9159242510795593, image_norm: 136.65994262695312, snr: 0.22355157136917114, grad_mean_norm: 243.8917236328125
level: 0, step_size: 47.03999999999986, grad_norm: 1.0196850299835205, image_norm: 177.761962890625, snr: 0.24864326417446136, grad_mean_norm: 85.75312805175781
level: 0, step_size: 47.03999999999986, grad_norm: 1.0420478582382202, image_norm: 203.01678466796875, snr: 0.2536088824272156, grad_mean_norm: 90.26822662353516
level: 0, step_size: 47.03999999999986, grad_norm: 1.040546178817749, image_norm: 223.3791046142578, snr: 0.255245178937912, grad_mean_norm: 88.83049011230469
level: 0, step_size: 47.03999999999986, grad_norm: 1.0552915334701538, image_norm: 241.1633758544922, snr: 0.2593735456466675, grad_mean_norm: 90.06800842285156
level: 0, step_size: 47.03999999999986, grad_norm: 1.046255350112915, image_norm: 255.1326141357422, snr: 0.2589056193828583, grad_mean_norm: 82.48102569580078
level: 0, step_size: 47.0399999999

Corruption MSE: 0.061749, Recovered MSE: 0.010772:   0%|▏                             | 10/2000 [00:02<07:04,  4.69it/s]

level: 73, step_size: 7.435628505275253e-06, grad_norm: 1756.0316162109375, image_norm: 8.873946189880371, snr: 0.17186549305915833, grad_mean_norm: 40.148109436035156
level: 74, step_size: 6e-06, grad_norm: 1948.3406982421875, image_norm: 8.873139381408691, snr: 0.16766342520713806, grad_mean_norm: 39.099159240722656
level: 74, step_size: 6e-06, grad_norm: 1938.2613525390625, image_norm: 8.872430801391602, snr: 0.16760388016700745, grad_mean_norm: 37.262847900390625
level: 74, step_size: 6e-06, grad_norm: 1922.4471435546875, image_norm: 8.873458862304688, snr: 0.16932858526706696, grad_mean_norm: 36.89952850341797
level: 74, step_size: 6e-06, grad_norm: 1923.2427978515625, image_norm: 8.873641014099121, snr: 0.16934734582901, grad_mean_norm: 38.14802932739258
level: 74, step_size: 6e-06, grad_norm: 1911.0859375, image_norm: 8.873896598815918, snr: 0.1662289947271347, grad_mean_norm: 38.34357833862305
level: 74, step_size: 6e-06, grad_norm: 1904.0819091796875, image_norm: 8.87335109710




In [2]:
# experiment on recovering
try:
    from NCSN.utils import val_set
    from NCSN.model import CondUNet, UNetv2
    from NCSN.train import evaluate_denoising, cal_noise_level, langevin
    from NCSN.ema import EMAHelper

except:
    from utils import val_set
    from model import CondUNet, UNetv2
    from train import evaluate_denoising, cal_noise_level, langevin
    from ema import EMAHelper

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

val_loader = DataLoader(val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

criterion = nn.MSELoss()
init_sigma = 28
final_sigma = 0.01
n_sigma = 75
eps = 6e-6
T = 10
ema_decay = 0.999

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

model = UNetv2(sigmas=torch.tensor(sigmas)).cuda()
time_str = "20241009_193254"
eval_epoch = 10

if ema_decay is not None:
    ema_helper = EMAHelper(ema_decay)
    ema_helper.register(model)
    ema_helper.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}_ema.pth"))
    ema_helper.ema(model)
else:
    model.load_state_dict(torch.load(f"/nobackup/users/sqa24/NCSN/{time_str}/models/{eval_epoch:03d}.pth"))
model.cuda()

corruption_mse, mse, original, broken, recovered = evaluate_denoising(
    model, sigmas, eps=eps, T=T, val_loader=val_loader, outdir=f'NCSN/rubbish.out'
)

FileNotFoundError: [Errno 2] No such file or directory: '/nobackup/users/sqa24/NCSN/20241009_193254/models/010_ema.pth'