In [1]:
import numpy as np
from tqdm import tqdm
from utils import *
from utils_data import *
import argparse
import os
import random
random.seed(0)

from models import Glow
from models.glow.coupling import UNet1
import util
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

In [2]:
def str2bool(s):
    """Convert a command-line string to a bool.

    Returns True when the lower-cased string starts with 't' (e.g. 'true',
    'True', 't'); every other string yields False.

    Used as the argparse ``type`` for all boolean flags below, because
    ``type=bool`` calls ``bool()`` on the raw string and any non-empty
    string -- including 'False' -- is truthy.
    """
    return s.lower().startswith('t')


parser = argparse.ArgumentParser()

# --- data / experiment selection -------------------------------------------
parser.add_argument('--mode', type=str, default='petgc')
# One or more integer noise levels. For PET-CT, noise_level = [PET, CT].
# nargs='+' with type=int keeps the value a list of ints; the previous
# type=list split the raw string into characters ('20' -> ['2', '0']).
parser.add_argument('--noise_level', type=int, nargs='+', default=[20])
parser.add_argument('--semi_sup', type=str2bool, default=True)
parser.add_argument('--supervision', type=float, default=0.0)
parser.add_argument('--secondary_noisy', type=str2bool, default=False)
parser.add_argument('--resume_training', type=int, default=0)
parser.add_argument('--train_size', type=int, default=800)
parser.add_argument('--blur_mode',type=str, default=None)
parser.add_argument('--new_range',type=int, default=2)

# --- transfer learning -------------------------------------------------------
parser.add_argument('--transfer_learning', type=str2bool, default=False)
parser.add_argument('--transfer_path', type=str, default='../results200_nd/unet_var_multi5/e3sgdws_petct_bpetnoperc_unet_var_ggg_multif_semi_0.0005-5000_0.5/model_400.ckpt')

# --- runtime -----------------------------------------------------------------
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--device', type=str, default='cuda:1')

# --- output locations --------------------------------------------------------
parser.add_argument('--save', type=str2bool, default=False)
parser.add_argument('--path', type=str, default='../results/e3adam_')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path_fig', type=str, default='')

# --- Glow model --------------------------------------------------------------
parser.add_argument('--num_levels', '-L', default=4, type=int, help='Number of levels in the Glow model')
parser.add_argument('--num_steps', '-K', default=8, type=int, help='Number of steps of flow in each level')
parser.add_argument('--cc', type = str2bool, default = False)
parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
parser.add_argument('--ext', default = 'll', type=str)

# An empty argv is passed explicitly so that only the defaults above are used
# and the notebook kernel's own command line is never parsed.
args = parser.parse_args(args=[])
# args_check(args)

In [3]:
# load_data(args) yields three values; only the middle one is kept and is the
# loader handed to test_model in the evaluation cells below. The first and
# third values are discarded (presumably the train / validation loaders --
# TODO confirm against load_data's definition).
_, test_dataloader, _ = load_data(args)

(800, 1, 256, 256) (433, 1, 256, 256)
same used


In [4]:
# Rebind args to whatever create_save_path returns; the evaluation cells read
# args.save_path from it when locating checkpoints and writing errors.txt.
# NOTE(review): presumably the helper derives save_path from args.path plus
# the run settings (the recorded output prints the resulting path) -- confirm.
args = create_save_path(args)

../results/e3adam_petgc_same_semi_20_0.0


In [5]:
# Checkpoint numbers to sweep over: every 25th from 25 up to and including 500.
# These are the values substituted into 'model_<n>.ckpt' by the evaluation loop.
CKPT_FIRST, CKPT_LAST, CKPT_STEP = 25, 500, 25
mods = np.arange(CKPT_FIRST, CKPT_LAST + 1, CKPT_STEP).astype('int')

# Filled by the evaluation loop: one test_model result per evaluated checkpoint.
mod_errors = []

print(mods)

[ 25  50  75 100 125 150 175 200 225 250 275 300 325 350 375 400 425 450
 475 500]


In [None]:
# Sweep over the checkpoint numbers in `mods`: rebuild the Glow model, load the
# matching weights from args.save_path, run test_model on the test loader, and
# collect the result in mod_errors (appended in the same order as `mods`).
# NOTE(review): `torch` is not imported explicitly in this notebook; it is
# presumably brought in by one of the star imports -- verify.
for mod in mods:
    ckpt_file = f"{args.save_path}/model_{mod!s}.ckpt"

    model = Glow(
        num_channels=128,
        num_levels=args.num_levels,
        num_steps=args.num_steps,
        inp_channel=1,
        cond_channel=1,
        cc=args.cc,
    )

    # Weights are deserialized onto the CPU first and only moved to
    # args.device afterwards, and only if CUDA is available.
    model.load_state_dict(torch.load(ckpt_file, map_location='cpu'))
    print(f"loaded model {mod!s}")
    if torch.cuda.is_available():
        model = model.to(args.device)

    err = test_model(args, model, test_dataloader)
    mod_errors.append(err)
    # Median along axis 0 of the returned errors (the recorded output shows
    # three values per checkpoint).
    print(np.median(err, 0))

loaded model 25


  0%|          | 0/433 [01:30<?, ?it/s]


[4.89174843e+00 6.08428122e-04 0.00000000e+00]
loaded model 50


  0%|          | 0/433 [01:32<?, ?it/s]


[1.59474039 0.10164069 0.        ]
loaded model 75


  0%|          | 0/433 [01:45<?, ?it/s]


[ 4.66963291e+00 -3.02978203e-04  0.00000000e+00]
loaded model 100


  0%|          | 0/433 [02:03<?, ?it/s]


[4.80275822e+00 4.83450098e-04 0.00000000e+00]
loaded model 125


  0%|          | 0/433 [02:05<?, ?it/s]


[2.56653738 0.0684304  0.        ]
loaded model 150


  0%|          | 0/433 [00:00<?, ?it/s]

In [None]:
# Append one "<checkpoint number>: <median along axis 0>" line per evaluated
# checkpoint to errors.txt inside args.save_path. The file is opened in append
# mode, so earlier contents are kept.
errors_txt = f"{args.save_path}/errors.txt"
with open(errors_txt, 'a') as f:
    for idx, ckpt_errors in enumerate(mod_errors):
        median_err = np.median(np.array(ckpt_errors), 0)
        f.write(f"{mods[idx]!s}: {median_err!s}\n")

# Evaluate one standalone checkpoint on the test loader, with both saving
# switches on args turned off.
args.save = args.save_nii = False
mod = 300  # only referenced by the commented-out alternatives below

# This checkpoint file is a dict; the model weights are stored under 'net'.
checkpoint = torch.load('../1.0_ll+sl_pl_out_0.02.pth.tar', map_location='cpu')

model = Glow(
    num_channels=128,
    num_levels=args.num_levels,
    num_steps=args.num_steps,
    inp_channel=1,
    cond_channel=1,
    cc=args.cc,
)
# Alternatives kept for reference:
# model.load_state_dict(torch.load(args.save_path+'/model_'+str(mod)+'.ckpt', map_location='cpu'))
# model.load_state_dict(torch.load('../1.0_ll+sl_pl_out_0.02.pth.tar', map_location='cpu'),strict=False)
model.load_state_dict(checkpoint['net'])

if torch.cuda.is_available():
    model = model.to(args.device)

errors = test_model(args, model, test_dataloader)
print(np.median(errors, 0))
# np.save(args.save_path+'/errors_300',errors)