In [1]:
import argparse
import os

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn, optim
from torch.backends import cudnn
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms.functional import to_pil_image
from ptflops import get_model_complexity_info
from tqdm import tqdm

from dataset import CTDataset,InferenceDataset
from logger import WANDBLoggerX
from networks import RED_CNN
from unet import UNet
from metrics.measure import compute_measure

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
def save_tensor_to_image(image_tensor,save_root,file_list):
    assert image_tensor.shape[0]==len(file_list),'image number should be same with filename'
    
    for i in range(len(file_list)):
        img = image_tensor[i]
        img = img.clamp(-1,1)*0.5+0.5
        img = to_pil_image(img)
        filename = file_list[i][:-4]+'_Rec.png'
        
        img.save(os.path.join(save_root,filename))
    
    

In [68]:
class config(object):
    data_root = './CT_Reconstruction_256x256_1m'
    
    # model_path = './output/unet4/L2_0427/save_models/model-0001000'
    # model_path = './output/redcnn/relu_L2_0427/save_models/model-0025000'
    model_path = './output/unet4/L1_0427/save_models/model-0025000'
    
    inference_path = './reconstruction'
    
    device = 'cuda'
    batch = 8
    num_workers = 16
    img_size = 256
    
args = config()

In [38]:
if not os.path.exists(args.inference_path):
    os.mkdir(args.inference_path)


test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=0.5,std=0.5)
])

test_dataset = InferenceDataset(args.data_root,mode='test',transform=test_transform)

print(test_dataset.__len__())

test_loader = DataLoader(test_dataset,
                            shuffle=False,
                            batch_size=args.batch,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            drop_last=True)

526


In [69]:
#model = RED_CNN()
model = UNet(repeat_num=4,conv_dim=64)

model.load_state_dict(torch.load(args.model_path,map_location='cpu'),strict=True)
model.to(args.device).eval()

macs, params = get_model_complexity_info(model, (3, args.img_size, args.img_size), as_strings=False,
                                           print_per_layer_stat=True, verbose=True)

print('{:<30}  {:.4f} GFLOPs'.format('Computational complexity:', macs/2/(10**9)))
print('{:<30}  {:<8} Params'.format('Number of parameters:', params/(10**6)))

UNet(
  24.18 M, 100.000% Params, 30.04 GFLOPS, 100.000% MACs, 
  (down_blocks): ModuleList(
    15.41 M, 63.741% Params, 19.11 GFLOPS, 63.607% MACs, 
    (0): DownBlock(
      104.58 k, 0.432% Params, 3.62 GFLOPS, 12.065% MACs, 
      (conv_res): Conv2d(256, 0.001% Params, 4.19 MFLOPS, 0.014% MACs, 3, 64, kernel_size=(1, 1), stride=(2, 2))
      (net): Sequential(
        38.72 k, 0.160% Params, 2.55 GFLOPS, 8.474% MACs, 
        (0): Conv2d(1.79 k, 0.007% Params, 117.44 MFLOPS, 0.391% MACs, 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(0, 0.000% Params, 4.19 MFLOPS, 0.014% MACs, negative_slope=0.2, inplace=True)
        (2): Conv2d(36.93 k, 0.153% Params, 2.42 GFLOPS, 8.055% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): LeakyReLU(0, 0.000% Params, 4.19 MFLOPS, 0.014% MACs, negative_slope=0.2, inplace=True)
      )
      (down): Conv2d(65.6 k, 0.271% Params, 1.07 GFLOPS, 3.577% MACs, 64, 64, kernel_size=(4, 4), stride=

In [70]:
cnt = 0
psnr,ssim,rmse = 0,0,0
with torch.no_grad():
    for test_i,(quarter_img,full_img,quarter_file,full_file) in tqdm(enumerate(test_loader),total=len(test_loader)):
        cnt += 1
        quarter_img,full_img = quarter_img.to(args.device),full_img.to(args.device)
        pred_img = model(quarter_img)
        
        save_tensor_to_image(pred_img,args.inference_path,quarter_file)
        
        result = compute_measure(pred_img,full_img)
        psnr += result['psnr']
        ssim += result['ssim']
        rmse += result['rmse']

print('reconstruction images are saved to {}'.format(args.inference_path))
print('psnr : ',psnr/cnt)
print('ssim : ',ssim/cnt)
print('rmse : ',rmse/cnt)

100%|██████████| 65/65 [00:15<00:00,  4.17it/s]

reconstruction images are saved to ./reconstruction
psnr :  tensor(39.6278, device='cuda:0')
ssim :  tensor(0.9567, device='cuda:0')
rmse :  tensor(2.7893, device='cuda:0')



