In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
g_num = "4"
os.environ["CUDA_VISIBLE_DEVICES"]=g_num

In [2]:
import argparse
import numpy as np

import sys
import torch
import torch.nn as nn
import tinycudann as tcnn
from datetime import timedelta
from datetime import datetime

import imageio
import cv2
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import mean_squared_error
from skimage.metrics import structural_similarity
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import time
import random
import wandb
from torchinfo import summary
from torchmetrics.functional.image import peak_signal_noise_ratio as psnr_torch
from torchmetrics.functional.image import structural_similarity_index_measure as ssim_torch
import math
from torch.optim.optimizer import Optimizer


In [3]:
SCRIPTS_DIR = "./tiny-cuda-nn/scripts"

sys.path.append(SCRIPTS_DIR)

from common import read_image, write_image, ROOT_DIR

DATA_DIR = os.path.join(ROOT_DIR, "data")
IMAGES_DIR = os.path.join(DATA_DIR, "images")

class Image(torch.nn.Module):
	def __init__(self, filename, device):
		super(Image, self).__init__()
		self.data = read_image(filename)
		self.shape = self.data.shape
		self.data = torch.from_numpy(self.data).float().to(device)

	def forward(self, xs):
		with torch.no_grad():
			# Bilinearly filtered lookup from the image. Not super fast,
			# but less than ~20% of the overall runtime of this example.
			shape = self.shape

			xs = xs * torch.tensor([shape[1], shape[0]], device=xs.device).float()
			indices = xs.long()
			lerp_weights = xs - indices.float()

			x0 = indices[:, 0].clamp(min=0, max=shape[1]-1)
			y0 = indices[:, 1].clamp(min=0, max=shape[0]-1)
			x1 = (x0 + 1).clamp(max=shape[1]-1)
			y1 = (y0 + 1).clamp(max=shape[0]-1)

			return (
				self.data[y0, x0] * (1.0 - lerp_weights[:,0:1]) * (1.0 - lerp_weights[:,1:2]) +
				self.data[y0, x1] * lerp_weights[:,0:1] * (1.0 - lerp_weights[:,1:2]) +
				self.data[y1, x0] * (1.0 - lerp_weights[:,0:1]) * lerp_weights[:,1:2] +
				self.data[y1, x1] * lerp_weights[:,0:1] * lerp_weights[:,1:2]
			)



def get_args():
	parser = argparse.ArgumentParser(description="Image benchmark using PyTorch bindings.")

	parser.add_argument("image", nargs="?", default="data/images/albert.jpg", help="Image to match")
	parser.add_argument("config", nargs="?", default="data/config_hash.json", help="JSON config for tiny-cuda-nn")
	parser.add_argument("n_steps", nargs="?", type=int, default=10000000, help="Number of training steps")
	parser.add_argument("result_filename", nargs="?", default="", help="Number of training steps")

	args = parser.parse_args()
	return args


In [4]:
def read_imagehsv(file):
    img = imageio.imread(file).astype(np.float32)
    for i in range(3):
        print(i,' max : ',img[:,:,i].max()," min : ", img[:,:,i].min())
    img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

    img = np.asarray(img).astype(np.float32)
    for i in range(3):
        print(i,' max : ',img[:,:,i].max()," min : ", img[:,:,i].min())

    img[:,:,0] = img[:,:,0] / 180.0
    img[:,:,1] = img[:,:,1] * 2
    img[:,:,2] = img[:,:,2] / 127.5
    
    
    return img -1
    
class Imagehsv(torch.nn.Module):
	def __init__(self, filename, device):
		super(Imagehsv, self).__init__()
		self.data = read_imagehsv(filename)
		self.shape = self.data.shape
		self.data = torch.from_numpy(self.data).float().to(device)

	def forward(self, xs):
		with torch.no_grad():
			# Bilinearly filtered lookup from the image. Not super fast,
			# but less than ~20% of the overall runtime of this example.
			shape = self.shape

			xs = xs * torch.tensor([shape[1], shape[0]], device=xs.device).float() #hsv
			indices = xs.long()
			lerp_weights = xs - indices.float()

			x0 = indices[:, 0].clamp(min=0, max=shape[1]-1)
			y0 = indices[:, 1].clamp(min=0, max=shape[0]-1)
			x1 = (x0 + 1).clamp(max=shape[1]-1)
			y1 = (y0 + 1).clamp(max=shape[0]-1)

			return (
				self.data[y0, x0] * (1.0 - lerp_weights[:,0:1]) * (1.0 - lerp_weights[:,1:2]) +
				self.data[y0, x1] * lerp_weights[:,0:1] * (1.0 - lerp_weights[:,1:2]) +
				self.data[y1, x0] * (1.0 - lerp_weights[:,0:1]) * lerp_weights[:,1:2] +
				self.data[y1, x1] * lerp_weights[:,0:1] * lerp_weights[:,1:2]
			)
def write_imagehsv(file, img, psnr=0.0):
    img = img.astype(np.float32)
    img += 1.0
    img[:,:,0] = img[:,:,0] * 180.0
    img[:,:,1] = img[:,:,1] /2
    img[:,:,2] = img[:,:,2] * 127.5

    
    img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
    
    img = (np.clip(img, 0.0, 255.0)).astype(np.uint8)

    if psnr > 0.1:
        file = f"{file[:-4]}_{psnr:02.4f}.png"
    
    imageio.imwrite(file, img)
    return img
def write_imagergb(file, img, quality=95):
    if img.shape[2] == 4:
        img = np.copy(img)
        # Unmultiply alpha
        img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0)
        img[...,0:3] = linear_to_srgb(img[...,0:3])
    else:
        img = linear_to_srgb(img)
    img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    imageio.imwrite(file, img)
    return img
def linear_to_srgb(img):
	limit = 0.0031308
	return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)

In [None]:
args = argparse.Namespace( 
    image='data/images/TCGA-J2-A4AE_0.png', 
    n_steps=3000,
    result_filename='',
    save_dir='result/a_0304_4',
    batch_size = 100,
    color = 'rgb',
    device=0
)


config = {'loss': 
              {'otype': 'RelativeL2'},
          'optimizer': 
              {'otype': 'Adam',
              'learning_rate': 0.01,
              'beta1': 0.9,
              'beta2': 0.99,
              'epsilon': 1e-15,
              'l2_reg': 1e-06},
         'encoding': 
              {'otype': 'HashGrid',
               'n_levels': 24,
              'n_features_per_level': 2,
              'log2_hashmap_size': 24,
              'base_resolution': 4,
              'per_level_scale': 1.5},
          'network': 
              {
              'otype': 'FullyFusedMLP',
              'activation': "ReLU", # 'ReLU',"Sine"
              'output_activation': "None",
              'n_neurons': 128,
              'n_hidden_layers': 2,
              }
         }
wandb.init(
            project = "pathology Neural Representation",
            
            notes = "image 1 in jupyter, hash( siren) ,rgb, cnn ",

            config={
                "date" : datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                "image": args.image,
                
                "batch size": args.batch_size,
                "epochs" : args.n_steps,
                "color" : args.color,
                "gpu_num": g_num
            }
        )
# define our custom x axis metric
wandb.define_metric("accumulated_study_time")
# set all other train/ metrics to use this step
wandb.define_metric("time/*", step_metric="accumulated_study_time")

In [6]:
args.image

'data/images/TCGA-J2-A4AE_0.png'

In [7]:
print("================================================================")
print("This script replicates the behavior of the native CUDA example  ")
print("mlp_learning_an_image.cu using tiny-cuda-nn's PyTorch extension.")
print("================================================================")

print(f"Using PyTorch version {torch.__version__} with CUDA {torch.version.cuda}")

device = torch.device("cuda:"+str(args.device))

print(os.listdir())

if args.color == 'rgb':
    image = Image(args.image, device)
elif args.color == 'hsv':
    image = Imagehsv(args.image, "cpu")
n_channels = image.data.shape[2]


#siren rgb
image.data = (image.data *2) -1

for i in range(3):
    print(image.data[:,:,i].min())
    print(image.data[:,:,i].max())
print(image.shape)

This script replicates the behavior of the native CUDA example  
mlp_learning_an_image.cu using tiny-cuda-nn's PyTorch extension.
Using PyTorch version 2.1.0 with CUDA 12.1
['tcnn_t_1122.ipynb', '.ipynb_checkpoints', 'result', 'data', 'tiny_cuda_nn_test_1106_1.ipynb', 'tcnn_t_1122_2.ipynb', 'mlp_cnn test.ipynb', 'save3.png', 'save3_1.png', 'tcnn_0108_cnn_1.ipynb', 'tcnn_1205_cnn.ipynb', 'tcnn_1206_cnn_1.ipynb', 'tcnn_0110_cnn_1.ipynb', 'tcnn_1220_cnn_1.ipynb', 'tcnn_1206_cnn_2.ipynb', 'tcnn_1206_cnn_3.ipynb', 'tcnn_1207_cnn_1.ipynb', 'tcnn_1207_cnn_2.ipynb', 'tcnn_1207_cnn_2_1.ipynb', 'tcnn_1207_cnn_3.ipynb', 'tcnn_1207_cnn_4.ipynb', 'tcnn_1207_cnn_5.ipynb', 'tcnn_1208_cnn_1.ipynb', 'tcnn_1208_cnn_2.ipynb', 'tcnn_1208_cnn_3.ipynb', 'tcnn_1207_cnn_5-Copy1.ipynb', 'tcnn_1207_cnn_5_1.ipynb', 'tcnn_1207_cnn_5_2.ipynb', 'tcnn_1208_cnn_4.ipynb', 'tcnn_1210_cnn_1.ipynb', 'tcnn_1210_cnn_2.ipynb', 'tcnn_1210_cnn_3.ipynb', 'tcnn_1210_cnn_4.ipynb', 'tcnn_1211_cnn_1.ipynb', 'tcnn_1211_cnn_2.ipynb'



tensor(-1., device='cuda:0')
tensor(1., device='cuda:0')
tensor(-1., device='cuda:0')
tensor(1., device='cuda:0')
tensor(-1., device='cuda:0')
tensor(1., device='cuda:0')
(10000, 10000, 3)


In [9]:
class Sine(nn.Module):
    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0
    def forward(self, x):
        return torch.sin(self.w0 * x)

class SineCNN(nn.Module):
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
    
    # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the 
    # nonlinearity. Different signals may require different omega_0 in the first layer - this is a 
    # hyperparameter.
    
    # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of 
    # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)
    
    def __init__(self, in_features, out_features, bias=False,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        
        self.in_features = in_features
        # self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.cnn = torch.nn.Conv2d(in_features, out_features, (3, 3), stride=(1, 1), padding=(1, 1), bias=bias, dtype =torch.half)
        
        self.init_weights()
    
    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                torch.nn.init.uniform_(self.cnn.weight, -1 / self.in_features, 
                                             1 / self.in_features)
            else:
                torch.nn.init.uniform_(self.cnn.weight, -torch.sqrt(torch.tensor(6) / torch.tensor(self.in_features)) / self.omega_0, 
                                             torch.sqrt(torch.tensor(6) / torch.tensor(self.in_features)) / self.omega_0)

    
    def forward(self, input):
        return torch.sin(self.omega_0 * self.cnn(input))


class my_hash(nn.Module):
    def __init__(self, n_input_dims=2, n_output_dims=n_channels, encoding_config=config["encoding"], network_config=config["network"]) -> None:
        super().__init__()
        self.encoder = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"])

        self.h = 100
        self.w = 100
        self.batch = 2

        # non-square kernels and unequal stride and with padding and dilation
        
        self.cnn_dim = 128
        self.cnn = torch.nn.Sequential(
            SineCNN(self.encoder.n_output_dims, self.cnn_dim, is_first=True),
            
            
            SineCNN(self.cnn_dim, self.cnn_dim, is_first=False),
            
        )
        
        self.mlp = tcnn.Network(n_input_dims=self.encoder.n_output_dims, n_output_dims=128, network_config=config["network"])
        self.cnn2 = torch.nn.Sequential(
            
            SineCNN(256, self.cnn_dim, is_first=True),
            
            
            SineCNN(self.cnn_dim, self.cnn_dim, is_first=False),
            
        )
        
        self.mlp2 = tcnn.Network(n_input_dims=128, n_output_dims=128, network_config=config["network"])
        self.cnn_last = torch.nn.Sequential(
            SineCNN(256, 3, is_first=False)
            
        )
        
    
    def forward(self, x):
        self.batch = x.size(dim=0)
        self.h = x.size(dim=1)
        self.w = x.size(dim=2)
        
        x = torch.reshape(x, (-1,2))
        
        x_encode = self.encoder(x)
        x_cnn = x_encode.reshape(self.batch,self.h , self.w, -1)
        
        x_cnn = x_cnn.permute(0,3,1,2)
        
        x_cnn = self.cnn(x_cnn)
        
        x_mlp = self.mlp(x_encode)
        x_mlp_ = x_mlp.reshape(self.batch, self.h , self.w, -1)
        x_mlp_ = x_mlp_.permute(0,3,1,2)
        x_cnn = torch.concat((x_cnn,x_mlp_), dim=1)
        

        x_cnn = self.cnn2(x_cnn)
        x_mlp = self.mlp2(x_mlp)
        x_mlp_ = x_mlp.reshape(self.batch, self.h , self.w, -1)
        x_mlp_ = x_mlp_.permute(0,3,1,2)
        x = torch.concat((x_cnn,x_mlp_), dim=1)
        
        x = self.cnn_last(x)
        return x

model = my_hash(n_input_dims=2, n_output_dims=n_channels, encoding_config=config["encoding"], network_config=config["network"]).to(device) #.to(device, dtype=torch.half)
print(model)



my_hash(
  (encoder): Encoding(n_input_dims=2, n_output_dims=48, seed=1337, dtype=torch.float16, hyperparams={'base_resolution': 4, 'hash': 'CoherentPrime', 'interpolation': 'Linear', 'log2_hashmap_size': 24, 'n_features_per_level': 2, 'n_levels': 24, 'otype': 'Grid', 'per_level_scale': 1.5, 'type': 'Hash'})
  (cnn): Sequential(
    (0): SineCNN(
      (cnn): Conv2d(48, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): SineCNN(
      (cnn): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
  (mlp): Network(n_input_dims=48, n_output_dims=128, seed=1337, dtype=torch.float16, hyperparams={'encoding': {'offset': 0.0, 'otype': 'Identity', 'scale': 1.0}, 'network': {'activation': 'ReLU', 'n_hidden_layers': 2, 'n_neurons': 128, 'otype': 'FullyFusedMLP', 'output_activation': 'None'}, 'otype': 'NetworkWithInputEncoding'})
  (cnn2): Sequential(
    (0): SineCNN(
      (cnn): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, betas=(0.9, 0.99),   fused=torch.float16)


print(summary(model, (args.batch_size, 128, 128,2)))
print("\nOptimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
    del var_name


# Variables for saving/displaying image results
resolution = image.data.shape[0:2]
img_shape = resolution + torch.Size([image.data.shape[2]])
n_pixels = resolution[0] * resolution[1]

half_dx =  0.5 / resolution[0]
half_dy =  0.5 / resolution[1]
xs = torch.linspace(half_dx, 1-half_dx, resolution[0], device=device)
ys = torch.linspace(half_dy, 1-half_dy, resolution[1], device=device)
xv, yv = torch.meshgrid([xs, ys])

# xy = torch.stack((yv.flatten(), xv.flatten())).t()

Layer (type:depth-idx)                   Output Shape              Param #
my_hash                                  [100, 3, 128, 128]        --
├─Encoding: 1-1                          [1638400, 48]             257,270,128
├─Sequential: 1-2                        [100, 128, 128, 128]      --
│    └─SineCNN: 2-1                      [100, 128, 128, 128]      --
│    │    └─Conv2d: 3-1                  [100, 128, 128, 128]      55,296
│    └─SineCNN: 2-2                      [100, 128, 128, 128]      --
│    │    └─Conv2d: 3-2                  [100, 128, 128, 128]      147,456
├─Network: 1-3                           [1638400, 128]            38,912
├─Sequential: 1-4                        [100, 128, 128, 128]      --
│    └─SineCNN: 2-3                      [100, 128, 128, 128]      --
│    │    └─Conv2d: 3-3                  [100, 128, 128, 128]      294,912
│    └─SineCNN: 2-4                      [100, 128, 128, 128]      --
│    │    └─Conv2d: 3-4                  [100, 128, 128, 1

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [11]:
xy = torch.stack((yv, xv)).permute(1,2,0) 
xy = xy.unsqueeze(0)

In [12]:
xy.shape

torch.Size([1, 10000, 10000, 2])

In [13]:
xy[0,-2:,-2:,:]

tensor([[[0.9998, 0.9998],
         [0.9999, 0.9998]],

        [[0.9998, 0.9999],
         [0.9999, 0.9999]]], device='cuda:0')

In [14]:
path = f"result/reference.png"
print(f"Writing '{path}'... ", end="")
# write_image(path, image(xy).reshape(img_shape).detach().cpu().numpy())
# print("done.")

prev_time = time.perf_counter()


interval = 10

print("================================================================")
total_params = 0
print("Model's state_dict:")
for param_tensor in model.state_dict():
    param = 1
    for p in  model.state_dict()[param_tensor].size():
        param *= p
        del p
    print(param_tensor, "\t", model.state_dict()[param_tensor].size(), "\t", param)
    total_params += param
    del param
print("Model's total params : ", total_params)
print("Model's bpp : ", total_params*16/(image.data.shape[0]*image.data.shape[1]) ,"\n")
wandb.config.update({"Model params": total_params, "Model bpp": total_params*16/(image.data.shape[0]*image.data.shape[1])})
del total_params

Model's state_dict:
encoder.params 	 torch.Size([257270128]) 	 257270128
cnn.0.cnn.weight 	 torch.Size([128, 48, 3, 3]) 	 55296
cnn.1.cnn.weight 	 torch.Size([128, 128, 3, 3]) 	 147456
mlp.params 	 torch.Size([38912]) 	 38912
cnn2.0.cnn.weight 	 torch.Size([128, 256, 3, 3]) 	 294912
cnn2.1.cnn.weight 	 torch.Size([128, 128, 3, 3]) 	 147456
mlp2.params 	 torch.Size([49152]) 	 49152
cnn_last.0.cnn.weight 	 torch.Size([3, 256, 3, 3]) 	 6912
Model's total params :  258010224
Model's bpp :  41.28163584 



In [15]:
class nrDataset(Dataset):
    def __init__(self, xy, image, image_size = 128, r=0.8, g=0.2 , rand=None):
        self.image = image.squeeze(0)
        self.xy = xy.squeeze(0)
        self.image_size = image_size
        self.r = r * image_size
        self.g = g * image_size
        self.rand = rand
        c, self.h, self.w = self.image.shape
        self.hs = math.ceil(self.h/(self.r))
        self.ws = math.ceil(self.w/(self.r))

    def __len__(self):
        return self.hs * self.ws

    def __getitem__(self, idx):
        if idx % self.ws == self.ws -1:
            x_start = self.w - self.image_size
        else:
            x_start = (idx % self.ws ) * self.r
            if self.rand != None and idx % self.ws != 0:
                x_start = x_start + ((random.random() - 0.5) * self.g)
        if idx // self.hs == self.hs -1:
            y_start = self.h - self.image_size
        else:
            y_start = (idx // self.ws) * self.r
            if self.rand != None and idx // self.hs != 0:
                y_start = y_start + ((random.random() - 0.5) * self.g)

        x_start = round(x_start)
        y_start = round(y_start)
        # print(x_start,y_start)
        
        image_ = self.image[:,y_start:y_start+self.image_size,x_start:x_start+self.image_size]
        xy_ = self.xy[y_start:y_start+self.image_size,x_start:x_start+self.image_size,:]
        
        # print(        self.image_size,        self.r,        self.g ,        self.rand ,        self.h, self.w,        self.hs,        self.ws)
        return xy_, image_


t_image = image.data.to(torch.half)
t_image = t_image.permute(2,0,1).unsqueeze(0)
dataset = nrDataset(xy, t_image, image_size = 128, r=0.8, g=0.2 , rand=True)
# dataset = nrDataset(xy, t_image.permute(2,0,1).unsqueeze(0), image_size = 128, r=0.99, g=0 )
train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# x,y = next(iter(train_dataloader))


In [None]:
print(f"Beginning optimization with {args.n_steps} training steps.")

try:
    batch, batch_label = next(iter(train_dataloader))
    traced_image = torch.jit.trace(model, batch)
except:
    # If tracing causes an error, fall back to regular execution
    print(f"WARNING: PyTorch JIT trace failed. Performance will be slightly worse than regular.")
    # traced_image = image
    # traced_image = t_image



# os.makedirs(os.path.dirname(args.save_dir), exist_ok=True)
os.makedirs(args.save_dir, exist_ok=True)
dir_len = len(os.listdir(args.save_dir))
os.makedirs(os.path.join(args.save_dir, str(dir_len)), exist_ok=True)
print('dir = ', os.path.join(args.save_dir, str(dir_len)))

psnr_origin = image.data.type(torch.float32).cpu().numpy()
now = datetime.now()
print("시작 시간 : ", now.strftime('%Y-%m-%d %H:%M:%S'))
start_time = time.perf_counter()
accumulated_study_time = 0
for i in range(args.n_steps):
    loss = 0.0
    ssim = 0.0
    psnr = 0.0
    stop = False
    train_start_time = time.perf_counter()
    for x, targets in train_dataloader:
        x = x.to(device)
        targets = targets.to(device)
        # targets += 1.0

        output = model(x)
        mse = nn.MSELoss()
        loss_ = mse(output, targets)

        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()

        if torch.isnan(loss_):
            print('steps : ',i, '  nan | last_loss : ', loss ," | total_time : ", timedelta(seconds=int(time.perf_counter() - start_time)), " [h:m:s]")
            stop = True
            break

        loss += loss_.item()
        ssim += ssim_torch(output, targets).item()
        psnr += psnr_torch(output, targets, data_range=(0,1)).item()


    loss = loss / len(train_dataloader)
    ssim = ssim / len(train_dataloader)
    psnr = psnr / len(train_dataloader)
    torch.cuda.synchronize()
    train_end_time = time.perf_counter()
    train_time_per_epoch = train_end_time-train_start_time
    time_from_start = train_end_time-start_time
    accumulated_study_time += train_time_per_epoch

    
    if stop:
        break

    print('steps : ',i, '  loss : ', loss," | ssim",ssim , " | total_time : ", timedelta(seconds=int(time.perf_counter() - start_time)), " [h:m:s]")

            

    if i % interval == 0 or i == (args.n_steps - 1) or i == 1 or i == 2 or i % 250 == 0: 
        loss_val = loss
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - prev_time
        print(f"\rStep#{i}: loss={loss_val} / {timedelta(seconds=int(elapsed_time))}[h:m:s] {int(elapsed_time)}[s] /time={int(elapsed_time*1000000)}[µs]")

        path = f"{args.save_dir}/{dir_len}/{i:07.0f}.png"
        print(f"Writing '{path}'... ", end="")
        write_image_start_time = time.perf_counter() 
        if args.color == 'rgb': 
            with torch.no_grad():
                pred_img = np.zeros((t_image.shape), dtype=np.float32)
                for y_ in range(10):
                    for x_ in range(10):
                        out_image = model(xy[:,y_*1000:(y_+1)*1000, x_*1000:(x_+1)*1000,:])
                        out_image = out_image.detach().cpu().numpy()
                        pred_img[0,:,y_*1000:(y_+1)*1000, x_*1000:(x_+1)*1000] = out_image
                output_image = np.squeeze(pred_img, 0)
                output_image = np.transpose(output_image, (1,2,0))
                output_image = np.clip(output_image, -1.0, 1.0) 


                rgb_loss = float(mean_squared_error(psnr_origin, output_image))
                rgb_ssim = ssim_torch(torch.from_numpy(psnr_origin+1).permute(2,0,1).unsqueeze(0), torch.from_numpy(output_image+1).permute(2,0,1).unsqueeze(0)).item()
                rgb_psnr = float(peak_signal_noise_ratio(psnr_origin, output_image))
                output_img= write_imagergb(path, (output_image+1)/2)
                
        elif args.color == 'hsv':
            with torch.no_grad():
                pred_img = np.zeros((t_image.shape), dtype=np.float32)
                for y_ in range(10):
                    for x_ in range(10):
                        out_image = model(xy[:,y_*1000:(y_+1)*1000, x_*1000:(x_+1)*1000,:])
                        out_image = out_image.detach().cpu().numpy()
                        pred_img[0,:,y_*1000:(y_+1)*1000, x_*1000:(x_+1)*1000] = out_image
                output_image = np.squeeze(pred_img, 0)
                output_image = np.transpose(output_image, (1,2,0))
                # output_image = np.clip(output_image, -1.0, 1.0) 
                output_image = np.clip(output_image, 0.0, 2.0) 
                # not siren
                output_image -= 1.0

                rgb_loss = float(mean_squared_error(psnr_origin, output_image))
                rgb_ssim = ssim_torch(torch.from_numpy(psnr_origin+1).permute(2,0,1).unsqueeze(0), torch.from_numpy(output_image+1).permute(2,0,1).unsqueeze(0)).item()
                rgb_psnr = float(peak_signal_noise_ratio(psnr_origin, output_image))
                output_img= write_imagehsv(path, output_image, rgb_psnr)
    
                
                
        write_image_time = time.perf_counter() - write_image_start_time
        print("done. | MSEloss :", rgb_loss," | PSNR : ", rgb_psnr, " | write_time : ", timedelta(seconds=int(write_image_time)), " [h:m:s] | total_time : ", timedelta(seconds=int(time.perf_counter() - start_time)), " [h:m:s]")
        wandb.log({"step":i,"hsv_loss": loss, "hsv_ssim": ssim, "hsv_psnr": psnr,  "accumulated_study_time": accumulated_study_time, "time per epoch": train_time_per_epoch,
                               "time/step":i,"time/hsv_loss": loss, "time/hsv_ssim": ssim, "time/hsv_psnr": psnr,
                               "rgb_loss": rgb_loss, "rgb_ssim": rgb_ssim, "rgb_psnr": rgb_psnr,
                               "result_image": wandb.Image(output_img)})

        # Ignore the time spent saving the image
        prev_time = time.perf_counter()

        if i != 1 and i !=2:
            if i > 0 and interval < 10000:
                interval *= 10
    else:
        wandb.log({"step":i,"hsv_loss": loss, "hsv_ssim": ssim, "hsv_psnr": psnr, "accumulated_study_time": accumulated_study_time, "time per epoch": train_time_per_epoch,
                               "time/step":i,"time/hsv_loss": loss, "time/hsv_ssim": ssim, "time/hsv_psnr": psnr})


# tcnn.free_temporary_memory()


Beginning optimization with 3000 training steps.


  x_padded = x if batch_size == padded_batch_size else torch.nn.functional.pad(x, [0, 0, 0, padded_batch_size - batch_size])


dir =  result/a_0304_4/0
시작 시간 :  2024-03-04 18:31:59
steps :  0   loss :  0.18421039384664947  | ssim 0.3567599562025562  | total_time :  0:00:32  [h:m:s]
Step#0: loss=0.18421039384664947 / 0:00:33[h:m:s] 33[s] /time=33803162[µs]
Writing 'result/a_0304_4/0/0000000.png'... done. | MSEloss : 0.06847115507234014  | PSNR :  17.665523370795675  | write_time :  0:01:02  [h:m:s] | total_time :  0:01:35  [h:m:s]
steps :  1   loss :  0.05906189594072165  | ssim 0.5445846085695877  | total_time :  0:02:41  [h:m:s]
Step#1: loss=0.05906189594072165 / 0:00:32[h:m:s] 32[s] /time=32680735[µs]
Writing 'result/a_0304_4/0/0000001.png'... done. | MSEloss : 0.042308924951415944  | PSNR :  19.756280010616518  | write_time :  0:01:01  [h:m:s] | total_time :  0:03:43  [h:m:s]
steps :  2   loss :  0.03811142125080541  | ssim 0.6146454172036082  | total_time :  0:04:47  [h:m:s]
Step#2: loss=0.03811142125080541 / 0:00:32[h:m:s] 32[s] /time=32339184[µs]
Writing 'result/a_0304_4/0/0000002.png'... done. | MSEloss

In [None]:
output_img.shape