In [1]:
import numpy as np
from tqdm import tqdm
from utils import *
from utils_data import *
import argparse
import os
import random
random.seed(0)

from models import Glow
from models.glow.coupling import UNet1
import util
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

In [2]:
def str2bool(s):
    """Parse a command-line string as a boolean.

    Returns True when ``s`` starts with 't' or 'T' (e.g. 'true', 'True', 't')
    and False for every other string.
    """
    return s.lower().startswith('t')


def str2num(s):
    """Parse a command-line string as an int when possible, otherwise as a float."""
    try:
        return int(s)
    except ValueError:
        return float(s)


parser = argparse.ArgumentParser()

# --- data / experiment options ---
parser.add_argument('--mode', type=str, default='ct')
# For PET-CT, noise_level = [PET, CT].
# nargs='+' collects one or more numbers into a list; the previous type=list
# split a single argument string into its individual characters.
parser.add_argument('--noise_level', type=str2num, nargs='+', default=[5000])
# Boolean flags go through str2bool: argparse's type=bool calls bool() on the
# raw string, so any non-empty value (including "False") would parse as True.
parser.add_argument('--semi_sup', type=str2bool, default=True)
parser.add_argument('--supervision', type=float, default=0.0)
parser.add_argument('--secondary_noisy', type=int, default=1)
parser.add_argument('--resume_training', type=int, default=0)
parser.add_argument('--train_size', type=int, default=200)
parser.add_argument('--blur_mode', type=str, default=None)
parser.add_argument('--new_range', type=int, default=2)

# --- transfer learning ---
parser.add_argument('--transfer_learning', type=str2bool, default=False)
parser.add_argument('--transfer_path', type=str, default='../results200_nd/unet_var_multi5/e3sgdws_petct_bpetnoperc_unet_var_ggg_multif_semi_0.0005-5000_0.5/model_400.ckpt')

# --- runtime ---
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--device', type=str, default='cuda:2')

# --- output locations ---
parser.add_argument('--save', type=str2bool, default=False)
parser.add_argument('--path', type=str, default='../results/e3adam_')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path_fig', type=str, default='')

# --- Glow model options ---
parser.add_argument('--num_levels', '-L', default=4, type=int, help='Number of levels in the Glow model')
parser.add_argument('--num_steps', '-K', default=8, type=int, help='Number of steps of flow in each level')
parser.add_argument('--cc', type=str2bool, default=False)
parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
parser.add_argument('--ext', default='ll', type=str)

# An explicit empty argv makes argparse ignore the kernel's own command line,
# so every option takes its default value here.
args = parser.parse_args(args=[])
# args_check(args)

In [3]:
# load_data(args) is unpacked into three values; only the second one is kept
# (as test_dataloader) and the other two are discarded.
# NOTE(review): presumably the discarded values are the train/validation
# loaders — confirm against utils_data.
_, test_dataloader, _ = load_data(args)

(200, 1, 512, 512) (326, 1, 512, 512)
1 secondary taget used


In [4]:
# Rebind args to whatever create_save_path returns; the cells below read
# args.save_path from it to locate checkpoints and write results.
# NOTE(review): presumably this fills in args.save_path (its parser default is
# the empty string) from the other options — confirm against the helper.
args = create_save_path(args)

../results/e3adam_ct_semi_5000_0.0


In [5]:
# Checkpoint indices to evaluate: every 25th from 25 up to and including 500.
mod_errors = []  # filled by the evaluation loop, one entry per checkpoint
mods = np.array(range(25, 501, 25), dtype='int')
print(mods)

[ 25  50  75 100 125 150 175 200 225 250 275 300 325 350 375 400 425 450
 475 500]


In [6]:
# Evaluate every checkpoint listed in `mods` on the test loader.
#
# The result list is reset here so that re-running this cell does not append a
# second round of results: a longer-than-`mods` list would misalign entries and
# make the index-based `mods[i]` lookup in the errors.txt cell raise IndexError.
mod_errors = []
for mod in mods:
    # Fresh Glow instance configured from the parsed options.
    model = Glow(num_channels=1,
                 num_levels=args.num_levels,
                 num_steps=args.num_steps,
                 inp_channel=1,
                 cond_channel=1,
                 cc=args.cc)

    # Weights are loaded onto the CPU first and moved to args.device below.
    checkpoint_path = args.save_path + '/model_' + str(mod) + '.ckpt'
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    print("loaded model " + str(mod))
    if torch.cuda.is_available():
        model = model.to(args.device)

    err = test_model(args, model, test_dataloader)
    mod_errors.append(err)
    # Median along axis 0 of whatever test_model returned for this checkpoint.
    print(np.median(err, 0))

loaded model 25


100%|██████████| 326/326 [01:20<00:00,  4.03it/s]


[0.09308659 0.58869824 0.        ]
loaded model 50


100%|██████████| 326/326 [01:19<00:00,  4.08it/s]


[0.09097944 0.58947625 0.        ]
loaded model 75


100%|██████████| 326/326 [01:19<00:00,  4.08it/s]


[0.08977844 0.60091038 0.        ]
loaded model 100


100%|██████████| 326/326 [01:20<00:00,  4.05it/s]


[0.09156996 0.57070493 0.        ]
loaded model 125


100%|██████████| 326/326 [01:19<00:00,  4.12it/s]


[0.09332922 0.5050998  0.        ]
loaded model 150


100%|██████████| 326/326 [01:19<00:00,  4.09it/s]


[0.09153318 0.58707855 0.        ]
loaded model 175


100%|██████████| 326/326 [01:19<00:00,  4.11it/s]


[0.09294662 0.52844128 0.        ]
loaded model 200


100%|██████████| 326/326 [01:19<00:00,  4.10it/s]


[0.09236951 0.59443132 0.        ]
loaded model 225


100%|██████████| 326/326 [01:19<00:00,  4.08it/s]


[0.09111975 0.59375355 0.        ]
loaded model 250


100%|██████████| 326/326 [01:20<00:00,  4.07it/s]


[0.0918774  0.54133534 0.        ]
loaded model 275


100%|██████████| 326/326 [01:19<00:00,  4.12it/s]


[0.09161341 0.58475275 0.        ]
loaded model 300


100%|██████████| 326/326 [01:19<00:00,  4.10it/s]


[0.09259629 0.53348209 0.        ]
loaded model 325


100%|██████████| 326/326 [01:19<00:00,  4.12it/s]


[0.09199211 0.56796941 0.        ]
loaded model 350


100%|██████████| 326/326 [01:21<00:00,  3.98it/s]


[0.09150258 0.56122712 0.        ]
loaded model 375


100%|██████████| 326/326 [01:20<00:00,  4.03it/s]


[0.09050513 0.59969379 0.        ]
loaded model 400


100%|██████████| 326/326 [01:20<00:00,  4.05it/s]


[0.09070345 0.60230673 0.        ]
loaded model 425


100%|██████████| 326/326 [01:20<00:00,  4.07it/s]


[0.09054676 0.58532464 0.        ]
loaded model 450


100%|██████████| 326/326 [01:19<00:00,  4.09it/s]


[0.09123593 0.57374065 0.        ]
loaded model 475


100%|██████████| 326/326 [01:19<00:00,  4.08it/s]


[0.09083655 0.60279985 0.        ]
loaded model 500


100%|██████████| 326/326 [01:20<00:00,  4.03it/s]

[0.09133236 0.58593808 0.        ]





In [7]:
# Append one "<checkpoint>: <median along axis 0>" line per evaluated model.
errors_log = f"{args.save_path}/errors.txt"
with open(errors_log, 'a') as log_file:
    for idx, model_err in enumerate(mod_errors):
        label = mods[idx]
        medians = np.median(np.array(model_err), 0)
        log_file.write(f"{label}: {medians}\n")

In [None]:
# Re-evaluate a single checkpoint with args.save enabled.
# NOTE(review): presumably test_model reads args.save / args.save_nii to decide
# what to write to disk — confirm against its definition.
args.save = True
args.save_nii = False
mod = 300
model = Glow(num_channels=1,
             num_levels=args.num_levels,
             num_steps=args.num_steps,
             inp_channel=1,
             cond_channel=1,
             cc=args.cc)
# Weights are loaded onto the CPU first and moved to args.device below.
model.load_state_dict(torch.load(args.save_path + '/model_' + str(mod) + '.ckpt', map_location='cpu'))
if torch.cuda.is_available():
    model = model.to(args.device)

errors = test_model(args, model, test_dataloader)
print(np.median(errors, 0))
# Name the saved array after the checkpoint that was actually evaluated; the
# previous hard-coded 'errors_300' would mislabel results if `mod` is changed.
np.save(args.save_path + '/errors_' + str(mod), errors)

 79%|███████▉  | 257/326 [01:34<00:24,  2.84it/s]