In [1]:
import numpy as np
from tqdm import tqdm
from utils import *
from utils_data import *
import argparse
import os
import random

from models import Glow
from models.glow.coupling import UNet1
import util
# Import torch explicitly: later cells call `torch.load` / `torch.cuda`, and relying on
# `from utils import *` to re-export it is fragile. It is imported after the star imports
# so they cannot rebind the name.
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

# Seed every RNG this notebook may touch (Python, NumPy, PyTorch) so runs are reproducible.
# torch.manual_seed also seeds the CUDA generators.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Experiment configuration. The notebook parses an empty argv (see parse_args below), so
# every value here is its default unless that list is edited.

def str2bool(s):
    """Parse a CLI boolean: any string starting with 't' or 'T' (e.g. 'True') is True.

    Used instead of ``type=bool``, because ``bool('False')`` is ``True`` (any non-empty
    string is truthy), so ``type=bool`` could never switch a flag off from the CLI.
    """
    return s.lower().startswith('t')


def str2num(s):
    """Parse a numeric CLI token: an ``int`` for integral values, otherwise a ``float``."""
    value = float(s)
    return int(value) if value.is_integer() else value


parser = argparse.ArgumentParser()

parser.add_argument('--mode', type=str, default='ct')
# For PET-CT, noise_level = [PET, CT]. `nargs='+'` collects one number per modality;
# `type=list` would instead split a single token into characters ('5000' -> ['5','0','0','0']).
parser.add_argument('--noise_level', type=str2num, nargs='+', default=[5000])
parser.add_argument('--semi_sup', type=str2bool, default=True)
parser.add_argument('--supervision', type=float, default=0.0)
parser.add_argument('--secondary_noisy', type=str2bool, default=False)
parser.add_argument('--resume_training', type=int, default=0)
parser.add_argument('--train_size', type=int, default=200)
parser.add_argument('--blur_mode',type=str, default=None)
parser.add_argument('--new_range',type=int, default=2)

parser.add_argument('--transfer_learning', type=str2bool, default=False)
parser.add_argument('--transfer_path', type=str, default='../results200_nd/unet_var_multi5/e3sgdws_petct_bpetnoperc_unet_var_ggg_multif_semi_0.0005-5000_0.5/model_400.ckpt')

parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--device', type=str, default='cuda:3')

parser.add_argument('--save', type=str2bool, default=False)
parser.add_argument('--path', type=str, default='../results/e3adam_')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path_fig', type=str, default='')

parser.add_argument('--num_levels', '-L', default=4, type=int, help='Number of levels in the Glow model')
parser.add_argument('--num_steps', '-K', default=8, type=int, help='Number of steps of flow in each level')
parser.add_argument('--cc', type = str2bool, default = False)
parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
parser.add_argument('--ext', default = 'll', type=str)

# Empty argv: Jupyter's own command-line arguments must not be parsed as experiment options.
args = parser.parse_args(args=[])
# args_check(args)  # NOTE(review): validation is disabled; re-enable or delete.

In [3]:
# load_data returns three loaders; only the second one (the test split) is used in this notebook.
# Index into the result instead of unpacking into `_`: assigning `_` in IPython shadows the
# output-history variable, after which IPython stops updating `_`, `__` and `___`.
test_dataloader = load_data(args)[1]

(200, 1, 512, 512) (326, 1, 512, 512)
same used


In [4]:
# Rebinds `args` to the value returned by create_save_path; later cells read
# args.save_path to locate checkpoints. Presumably this helper fills in save_path
# (default '') from the other options — confirm in utils. The path printed below
# is this call's output.
args = create_save_path(args)

../results/e3adam_ct_same_semi_5000_0.0


In [5]:
# Checkpoint indices to evaluate: 25, 50, ..., 500 (upper bound inclusive).
mods = np.arange(25, 500 + 1, 25, dtype=int)
# Collects one error array per checkpoint in `mods`.
mod_errors = list()
print(mods)

[ 25  50  75 100 125 150 175 200 225 250 275 300 325 350 375 400 425 450
 475 500]


In [6]:
# Evaluate every saved checkpoint listed in `mods` on the test set and collect its errors.
# `mod_errors` is reset here so this cell is idempotent: appending to the list created in
# the previous cell meant a re-run duplicated entries and misaligned them with `mods`.
mod_errors = []
for mod in mods:
    model = Glow(num_channels=1,
                 num_levels=args.num_levels,
                 num_steps=args.num_steps,
                 inp_channel=1,
                 cond_channel=1,
                 cc=args.cc)

    # map_location='cpu' remaps the saved tensors to CPU, independent of the device they
    # were saved from. NOTE: torch.load unpickles the file — only load trusted checkpoints.
    ckpt_path = os.path.join(args.save_path, 'model_' + str(mod) + '.ckpt')
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    print("loaded model " + str(mod))
    if torch.cuda.is_available():
        model = model.to(args.device)

    err = test_model(args, model, test_dataloader)
    mod_errors.append(err)
    # Median of each error column over the test samples (axis 0).
    print(np.median(err, 0))

loaded model 25


100%|██████████| 326/326 [01:31<00:00,  3.58it/s]


[0.17927298 0.19543545 0.        ]
loaded model 50


100%|██████████| 326/326 [01:34<00:00,  3.45it/s]


[0.11959054 0.35728384 0.        ]
loaded model 75


100%|██████████| 326/326 [01:30<00:00,  3.61it/s]


[0.10131329 0.49519941 0.        ]
loaded model 100


100%|██████████| 326/326 [01:31<00:00,  3.57it/s]


[0.09011939 0.50084336 0.        ]
loaded model 125


100%|██████████| 326/326 [01:28<00:00,  3.67it/s]


[0.09696758 0.48906342 0.        ]
loaded model 150


100%|██████████| 326/326 [01:30<00:00,  3.61it/s]


[0.25290568 0.14714801 0.        ]
loaded model 175


100%|██████████| 326/326 [01:30<00:00,  3.60it/s]


[0.14304368 0.31615335 0.        ]
loaded model 200


100%|██████████| 326/326 [01:31<00:00,  3.54it/s]


[0.09149094 0.55271951 0.        ]
loaded model 225


100%|██████████| 326/326 [01:30<00:00,  3.60it/s]


[0.09381359 0.52466382 0.        ]
loaded model 250


100%|██████████| 326/326 [01:32<00:00,  3.52it/s]


[0.56237742 0.06478868 0.        ]
loaded model 275


100%|██████████| 326/326 [01:31<00:00,  3.54it/s]


[0.0876531  0.61254217 0.        ]
loaded model 300


100%|██████████| 326/326 [01:29<00:00,  3.65it/s]


[0.08944085 0.55885146 0.        ]
loaded model 325


100%|██████████| 326/326 [01:30<00:00,  3.58it/s]


[0.0879802  0.54635813 0.        ]
loaded model 350


100%|██████████| 326/326 [01:31<00:00,  3.57it/s]


[0.08731568 0.58234704 0.        ]
loaded model 375


100%|██████████| 326/326 [01:32<00:00,  3.54it/s]


[0.09013198 0.51121455 0.        ]
loaded model 400


100%|██████████| 326/326 [01:32<00:00,  3.52it/s]


[0.08824774 0.54893965 0.        ]
loaded model 425


100%|██████████| 326/326 [01:34<00:00,  3.46it/s]


[0.08745656 0.6021656  0.        ]
loaded model 450


100%|██████████| 326/326 [01:43<00:00,  3.15it/s]


[0.08746468 0.59796164 0.        ]
loaded model 475


100%|██████████| 326/326 [01:43<00:00,  3.15it/s]


[0.08832715 0.52775365 0.        ]
loaded model 500


100%|██████████| 326/326 [01:43<00:00,  3.14it/s]

[0.08805778 0.6202136  0.        ]





In [7]:
# Record the median test error of each evaluated checkpoint. The file is opened in append
# mode, so results from repeated runs accumulate in errors.txt rather than overwrite it.
if len(mod_errors) > len(mods):
    raise ValueError(
        f"mod_errors has {len(mod_errors)} entries but only {len(mods)} checkpoints were "
        "listed; re-run the checkpoint cell from a fresh list before writing results.")
with open(os.path.join(args.save_path, 'errors.txt'), 'a') as f:
    for ckpt, err in zip(mods, mod_errors):
        f.write(str(ckpt) + ': ' + str(np.median(np.array(err), 0)) + '\n')

# Re-evaluate one checkpoint with saving enabled.
# NOTE(review): this mutates the shared `args`; cells re-run afterwards also see save=True.
args.save= True
args.save_nii= False
mod = 300  # hard-coded; not derived from mod_errors
model = Glow(num_channels=1,
             num_levels=args.num_levels,
             num_steps=args.num_steps,
             inp_channel=1,
             cond_channel=1,
             cc=args.cc)
model.load_state_dict(torch.load(os.path.join(args.save_path, 'model_' + str(mod) + '.ckpt'),
                                 map_location='cpu'))
if torch.cuda.is_available():
    model = model.to(args.device)

errors = test_model(args, model, test_dataloader)
print(np.median(errors, 0))
# Name the array after `mod` so it always matches the checkpoint actually evaluated
# (previously hard-coded to 'errors_300'). np.save appends the '.npy' extension.
np.save(os.path.join(args.save_path, 'errors_' + str(mod)), errors)