# Implementation of NCSN (Noise Conditional Score Networks)

In [1]:
# Reload imported project modules automatically when their source files change,
# so edits to the NCSN/*.py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Record the GPU/driver and the Python interpreter used for this run (provenance).
!nvidia-smi
!which python

Sun Sep 29 15:46:38 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   38C    P0              39W / 184W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Project modules: import from the `NCSN` package when it is importable, otherwise
# fall back to the same modules imported by their top-level names.
try:
    from NCSN.train import train, make_dataset, cal_noise_level, langevin
    from NCSN.utils import train_set, val_set
    from NCSN.model import UNet
except ImportError:
    # Only import failures trigger the fallback. Any other exception raised while
    # executing those modules (e.g. a bug in NCSN/train.py) now propagates instead
    # of being masked by a confusing second import attempt.
    from train import train, make_dataset, cal_noise_level, langevin
    from utils import train_set, val_set
    from model import UNet

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import datetime
import os

# Mini-batch loaders over the datasets exported by the project `utils` module.
# Both keep the final partial batch (drop_last=False) and request pinned host
# memory for faster host-to-GPU copies.
train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)
# NOTE(review): the validation loader is shuffled too. That is harmless for an
# averaged loss; confirm whether train() relies on drawing random val batches.
val_loader = DataLoader(
    val_set,
    batch_size=500,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

def timestr():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string.

    The result is used as the run identifier for log files and sample folders.
    """
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def get_outdir(time_str):
    """Return the training-log file path for the run identified by ``time_str``.

    The path is relative to the working directory. No directory is created here.
    """
    return "NCSN/training_data/{}.out".format(time_str)

def get_sample_dir(time_str):
    """Return the sample-image directory for a run, creating it if needed.

    Args:
        time_str: Run identifier (e.g. the output of ``timestr()``).

    Returns:
        The relative path ``"NCSN/samples/<time_str>/"`` (with trailing slash).
        The directory exists when the function returns.
    """
    sample_dir = f"NCSN/samples/{time_str}/"
    # exist_ok=True already tolerates an existing directory, including one created
    # concurrently, so a separate os.path.exists() pre-check is unnecessary.
    os.makedirs(sample_dir, exist_ok=True)
    return sample_dir

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Here are the hyperparameters

SEED = 0
# Seed torch's global RNG (CPU and all CUDA devices) before building the model.
# This makes weight init, DataLoader shuffling and the noise drawn during
# training reproducible.
torch.manual_seed(SEED)

epochs = 200
criterion = nn.MSELoss()
# Noise schedule: n_sigma levels between init_sigma and final_sigma.
# NOTE(review): the spacing is decided by cal_noise_level (presumably geometric,
# as in the NCSN paper) — confirm.
init_sigma = 1
final_sigma = 0.01
n_sigma = 10
eps = 2e-5  # passed to train()/langevin(); presumably the Langevin step size — confirm
T = 100     # passed to train()/langevin(); presumably Langevin steps per noise level — confirm

model = UNet(L=n_sigma)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

# One identifier shared by this run's log file and its sample directory.
time_str = timestr()

train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    sigmas=sigmas,
    eps=eps,
    T=T,
    outdir=get_outdir(time_str),
    eval_freq=1,
    sample_dir=get_sample_dir(time_str),
)

                                                                                                                        

Epoch 0: train loss 0.8027304471807277


                                                                                                                        

Epoch 1: train loss 0.5942564124756671


                                                                                                                        

Epoch 2: train loss 0.5071269518517433


                                                                                                                        

Epoch 3: train loss 0.44996082858836395


                                                                                                                        

Epoch 4: train loss 0.40960474965420174


                                                                                                                        

Epoch 5: train loss 0.3774731269542207


                                                                                                                        

Epoch 6: train loss 0.3484503446741307


                                                                                                                        

Epoch 7: train loss 0.32572210240871347


                                                                                                                        

Epoch 8: train loss 0.30425557849255014


                                                                                                                        

Epoch 9: train loss 0.28651224875703774


                                                                                                                        

Epoch 10: train loss 0.27154455742937456


                                                                                                                        

Epoch 11: train loss 0.25775412986887264


                                                                                                                        

Epoch 12: train loss 0.2461381214096191


                                                                                                                        

Epoch 13: train loss 0.23549187931608648


                                                                                                                        

Epoch 14: train loss 0.2243784614699952


                                                                                                                        

Epoch 15: train loss 0.21623099363864737


                                                                                                                        

Epoch 16: train loss 0.20894186002142887


                                                                                                                        

Epoch 17: train loss 0.20024394335898948


                                                                                                                        

Epoch 18: train loss 0.19449730876912463


                                                                                                                        

Epoch 19: train loss 0.18727004749977844


                                                                                                                        

Epoch 20: train loss 0.18203731046078053


                                                                                                                        

Epoch 21: train loss 0.1773040126612846


                                                                                                                        

Epoch 22: train loss 0.1725069266684512


                                                                                                                        

Epoch 23: train loss 0.1678983857022955


                                                                                                                        

Epoch 24: train loss 0.16382432326357416


                                                                                                                        

Epoch 25: train loss 0.16077640665338394


                                                                                                                        

Epoch 26: train loss 0.15843068367623267


                                                                                                                        

Epoch 27: train loss 0.15658274464150693


                                                                                                                        

Epoch 28: train loss 0.1536760980778552


                                                                                                                        

Epoch 29: train loss 0.15053733003900407


                                                                                                                        

Epoch 30: train loss 0.14932010008933697


                                                                                                                        

Epoch 31: train loss 0.1462928092226069


                                                                                                                        

Epoch 32: train loss 0.1450660435443229


                                                                                                                        

Epoch 33: train loss 0.14321620756007256


                                                                                                                        

Epoch 34: train loss 0.14118964158474132


                                                                                                                        

Epoch 35: train loss 0.1404892438903768


                                                                                                                        

Epoch 36: train loss 0.13902046987985042


                                                                                                                        

Epoch 37: train loss 0.1371648964729715


                                                                                                                        

Epoch 38: train loss 0.13572523679505002


                                                                                                                        

Epoch 39: train loss 0.13465283316500642


                                                                                                                        

Epoch 40: train loss 0.13385317236819166


                                                                                                                        

Epoch 41: train loss 0.132306074176697


                                                                                                                        

Epoch 42: train loss 0.13155715849171293


                                                                                                                        

Epoch 43: train loss 0.130068901181221


                                                                                                                        

Epoch 44: train loss 0.12965376668788017


                                                                                                                        

Epoch 45: train loss 0.12857966071113627


                                                                                                                        

Epoch 46: train loss 0.12738448058037047


                                                                                                                        

Epoch 47: train loss 0.12671609999651604


                                                                                                                        

Epoch 48: train loss 0.12551254652282026


                                                                                                                        

Epoch 49: train loss 0.12519078311767984


                                                                                                                        

Epoch 50: train loss 0.12438816659628077


                                                                                                                        

Epoch 51: train loss 0.12392515885703106


                                                                                                                        

Epoch 52: train loss 0.12180855131529747


                                                                                                                        

Epoch 53: train loss 0.12200245603601984


                                                                                                                        

Epoch 54: train loss 0.12149035746746875


                                                                                                                        

Epoch 55: train loss 0.12090098626436072


                                                                                                                        

Epoch 56: train loss 0.11971485801833741


                                                                                                                        

Epoch 57: train loss 0.11972511439247334


                                                                                                                        

Epoch 58: train loss 0.11842030074368132


                                                                                                                        

Epoch 59: train loss 0.11801072552483133


                                                                                                                        

Epoch 60: train loss 0.11768950829480557


                                                                                                                        

Epoch 61: train loss 0.11723701361011951


                                                                                                                        

Epoch 62: train loss 0.11647353920530766


                                                                                                                        

Epoch 63: train loss 0.11601450199142416


                                                                                                                        

Epoch 64: train loss 0.1155515709138931


                                                                                                                        

Epoch 65: train loss 0.11565016181545054


                                                                                                                        

Epoch 66: train loss 0.11464460901123412


                                                                                                                        

Epoch 67: train loss 0.11411364731636453


                                                                                                                        

Epoch 68: train loss 0.11389474583433029


                                                                                                                        

Epoch 69: train loss 0.11346369797879077


                                                                                                                        

Epoch 70: train loss 0.11313943713903427


                                                                                                                        

Epoch 71: train loss 0.11296797052342841


                                                                                                                        

Epoch 72: train loss 0.11198607943159469


                                                                                                                        

Epoch 73: train loss 0.11196528562205903


                                                                                                                        

Epoch 74: train loss 0.11113133474867394


                                                                                                                        

Epoch 75: train loss 0.1109364243263894


                                                                                                                        

Epoch 76: train loss 0.11084257056738468


                                                                                                                        

Epoch 77: train loss 0.11062931288430031


                                                                                                                        

Epoch 78: train loss 0.11063765851741142


                                                                                                                        

Epoch 79: train loss 0.10947750898751807


                                                                                                                        

Epoch 80: train loss 0.10941552453218623


                                                                                                                        

Epoch 81: train loss 0.10962604686934897


                                                                                                                        

Epoch 82: train loss 0.10917029431525697


                                                                                                                        

Epoch 83: train loss 0.10876119802606866


                                                                                                                        

Epoch 84: train loss 0.10847542549067356


                                                                                                                        

Epoch 85: train loss 0.10820207167813119


                                                                                                                        

Epoch 86: train loss 0.1079725203044871


                                                                                                                        

Epoch 87: train loss 0.10785106214437079


                                                                                                                        

Epoch 88: train loss 0.10749762124837713


                                                                                                                        

Epoch 89: train loss 0.10679648134936677


                                                                                                                        

Epoch 90: train loss 0.10690671706453282


                                                                                                                        

Epoch 91: train loss 0.10674774145826381


                                                                                                                        

Epoch 92: train loss 0.10636287528149625


Epoch 94/200:  51%|████████████████████████▊                        | 119/235 [00:13<00:12,  8.96it/s, Train Loss=0.106]

: 

In [None]:
# load a model
# Load a saved checkpoint, generate a sample dataset from it, then run the
# external evaluation script.
# Relies on n_sigma, sigmas, eps and T defined in the hyperparameter cell.
import subprocess
import sys

model = UNet(L=n_sigma).cuda()
# Number of trainable parameters.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
eval_epoch = 70
# A state_dict only holds tensors, so weights_only=True is sufficient. It also
# prevents arbitrary code from running while the pickle is loaded.
model.load_state_dict(torch.load(f'NCSN/models/{eval_epoch}.pth', weights_only=True))
# Inference mode. This affects layers such as dropout/batch-norm, if UNet has any.
model.eval()
make_dataset(model, sigmas, eps=eps, T=T)
print("Dataset created")
# Run the evaluator with the kernel's own interpreter; a bare "python" on PATH may
# belong to a different environment. check=True raises on a non-zero exit instead
# of silently ignoring the status as os.system() did.
subprocess.run([sys.executable, "NCSN/evaluate.py"], check=True)

In [4]:
# Sample images from a saved checkpoint with Langevin dynamics.
# The noise schedule must match the one used for training.
init_sigma = 1
final_sigma = 0.01
n_sigma = 10
eps = 2e-5
T = 100
clamp = False
SEED = 0

sigmas = cal_noise_level(init_sigma, final_sigma, n_sigma)

eval_epoch = 70
model = UNet(L=n_sigma).cuda()
# A state_dict only holds tensors, so weights_only=True is sufficient and safer.
model.load_state_dict(torch.load(f"NCSN/models/{eval_epoch}.pth", weights_only=True))
# Inference mode. This affects layers such as dropout/batch-norm, if UNet has any.
model.eval()

# Reproducible initial noise for the sampling chains.
torch.manual_seed(SEED)
# Initial state: standard Gaussian mapped by (x + 1) / 2, i.e. N(0.5, 0.5^2) per pixel.
# NOTE(review): the NCSN paper starts chains from uniform noise (torch.rand);
# confirm this Gaussian start is intentional.
x = torch.randn(10, 1, 28, 28, device="cuda")
x = (x + 1) / 2
y = langevin(model, x, sigmas, eps=eps, T=T, save=True, epochs=eval_epoch, clamp=clamp)