In [1]:
import numpy as np
from tqdm import tqdm
from utils import *
from utils_data import *
import argparse
import os
import random
random.seed(0)

from models import Glow
from models.glow.coupling import UNet1
import util
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

In [2]:
def str2bool(s):
    """Parse a boolean command-line value.

    Strings starting with 't' (any case), plus 'y', 'yes' and '1', are True;
    anything else is False. Real ``bool`` values pass through unchanged.
    Use this instead of ``type=bool``: argparse would call ``bool('False')``,
    which is True.
    """
    if isinstance(s, bool):
        return s
    s = str(s).lower()
    return s.startswith('t') or s in ('y', 'yes', '1')


def str2num(s):
    """Parse a numeric command-line value as int when possible, else float."""
    try:
        return int(s)
    except ValueError:
        return float(s)


parser = argparse.ArgumentParser()

parser.add_argument('--mode', type=str, default='ct')
# For PET-CT, noise_level = [PET, CT]; on the CLI pass values separated by spaces.
parser.add_argument('--noise_level', type=str2num, nargs='+', default=[5000])
parser.add_argument('--semi_sup', type=str2bool, default=True)
parser.add_argument('--supervision', type=float, default=0.0)
parser.add_argument('--secondary_noisy', type=int, default=0)
parser.add_argument('--resume_training', type=int, default=0)
parser.add_argument('--train_size', type=int, default=200)
parser.add_argument('--blur_mode', type=str, default=None)
parser.add_argument('--new_range', type=int, default=2)

parser.add_argument('--transfer_learning', type=str2bool, default=False)
parser.add_argument('--transfer_path', type=str, default='../results200_nd/unet_var_multi5/e3sgdws_petct_bpetnoperc_unet_var_ggg_multif_semi_0.0005-5000_0.5/model_400.ckpt')

parser.add_argument('--n_samples', type=int, default=1)
parser.add_argument('--s_samples', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--epoch_num', type=int, default=500)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--device', type=str, default='cuda:4')
parser.add_argument('--perceptual', type=str2bool, default=False)
# (pet, ct, latent); on the CLI pass exactly three values.
parser.add_argument('--weights', type=str2num, nargs=3, default=[1, 1, 1])
parser.add_argument('--target', type=str, default=None)

# Adversarial-attack hyper-parameters. They are the defaults of
# pgd_linf_manual and are also encoded in the run directory name below.
eps = .1     # perturbation budget
alp = .01    # step size
it = 5       # number of PGD iterations
p = 'inf'    # attack norm: 'inf' or 'l2'
parser.add_argument('--save', type=str2bool, default=True)
parser.add_argument('--path', type=str, default='../results_adv/e3adam_e'+str(eps)+'_a'+str(alp)+'_i'+str(it)+'_p'+str(p)+'_')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path_fig', type=str, default='')

parser.add_argument('--num_levels', '-L', default=4, type=int, help='Number of levels in the Glow model')
parser.add_argument('--num_steps', '-K', default=8, type=int, help='Number of steps of flow in each level')
parser.add_argument('--cc', type=str2bool, default=False)
parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
parser.add_argument('--ext', default='ll', type=str)

# Parse an empty argv so only the defaults above are used
# (inside a notebook, sys.argv belongs to the kernel).
args = parser.parse_args(args=[])
# args_check(args)  # optional validation helper, currently disabled

In [3]:
# Build the train / test / validation loaders from the parsed configuration.
# NOTE(review): load_data comes from one of the star imports (utils / utils_data);
# it prints array shapes and a "same used" message (see the cell output) — confirm
# which printed shape belongs to which split.
trainloader, testloader, validloader = load_data(args)

(200, 1, 512, 512) (326, 1, 512, 512)
same used


In [4]:
# Derive the run-specific output directory. Judging by the cell output it appends
# mode / noise / supervision suffixes to args.path and creates the directory.
# NOTE(review): the checkpoint code later reads args.save_path — confirm that
# create_save_path is what sets it.
args = create_save_path(args)

../results_adv/e3adam_e0.1_a0.01_i5_pinf_ct_same_semi_5000_0.0
Create path : ../results_adv/e3adam_e0.1_a0.01_i5_pinf_ct_same_semi_5000_0.0


In [5]:
# Conditional Glow flow; train() calls it as net(x, cond_x, reverse=False).
net = Glow(
    num_channels=1,
    num_levels=args.num_levels,
    num_steps=args.num_steps,
    inp_channel=1,
    cond_channel=1,
    cc=args.cc,
).to(args.device)
# Let cuDNN benchmark and cache the fastest convolution algorithms.
cudnn.benchmark = True

# Resume from a checkpoint written by the training loop (model_<epoch>.ckpt).
if args.resume_training:
    resume_ckpt = args.save_path + '/model_' + str(args.resume_training) + '.ckpt'
    state_dict = torch.load(resume_ckpt, map_location='cpu')
    net.load_state_dict(state_dict)

# Auxiliary U-Net, passed to train() as `model`; no pretrained weights are loaded here.
unet = UNet1(inp_channels=1, op_channels=1).to(args.device)

In [6]:
def norm_l2(Z):
    """Per-sample RMS norm over all but the first dimension.

    Computes ``||Z[i]||_2 / sqrt(numel(Z[i]))`` for each sample ``i``.

    Args:
        Z: tensor with a leading batch dimension and at least one more
            dimension. It does not need to be contiguous.

    Returns:
        Tensor of shape ``(Z.shape[0], 1, ..., 1)`` with the same rank as ``Z``,
        so it broadcasts against ``Z`` (``(N, 1, 1, 1)`` for 4-D input).
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. gradients
    # of permuted/transposed tensors.
    per_sample = Z.reshape(Z.shape[0], -1).norm(dim=1)
    # One singleton axis per non-batch dimension for broadcasting.
    per_sample = per_sample.view(Z.shape[0], *([1] * (Z.dim() - 1)))
    return per_sample / np.sqrt(np.prod(Z.shape[1:]))

def pgd_linf_manual(model, x, cond_x, loss_fn, reverse, epsilon=eps, alpha=alp, num_iter=it, randomize=True, norm=p):
    """Construct a PGD adversarial perturbation of the conditioning input.

    Runs ``num_iter`` steps of projected gradient *ascent* on
    ``loss_fn(*model(x, cond_x + delta, reverse=reverse))`` with respect to
    ``delta`` only. Gradients are taken with ``torch.autograd.grad``, so the
    model parameters' ``.grad`` fields are left untouched.

    Args:
        model: flow called as ``model(x, cond, reverse=...)`` returning ``(z, sldj)``.
        x: input passed to ``model`` unchanged.
        cond_x: conditioning input; the perturbation is added to it.
        loss_fn: ``loss_fn(z, sldj)`` -> scalar loss to maximise.
        reverse: forwarded to ``model``.
        epsilon: perturbation budget. It is the L-inf radius for ``'inf'``, and the
            radius in the per-sample RMS norm of ``norm_l2`` for ``'l2'``.
        alpha: step size.
        num_iter: number of ascent steps.
        randomize: start from uniform noise in ``[-epsilon, epsilon]`` instead of zeros.
        norm: ``'inf'`` or ``'l2'``.

    Returns:
        Detached perturbation shaped like ``cond_x``; all zeros if it became NaN.

    Raises:
        ValueError: if ``norm`` is neither ``'inf'`` nor ``'l2'``.
    """
    if norm not in ('inf', 'l2'):
        raise ValueError(f"norm must be 'inf' or 'l2', got {norm!r}")

    # The perturbation is added to cond_x, so it must be shaped like cond_x.
    if randomize:
        delta = (torch.rand_like(cond_x) * 2 * epsilon - epsilon).requires_grad_(True)
    else:
        delta = torch.zeros_like(cond_x, requires_grad=True)

    # Enable autograd explicitly so the attack also works from a no_grad context.
    with torch.enable_grad():
        for _ in range(num_iter):
            z, sldj = model(x, cond_x + delta, reverse=reverse)
            latent_loss = loss_fn(z, sldj)
            (grad,) = torch.autograd.grad(latent_loss, delta)
            with torch.no_grad():
                if norm == 'l2':
                    # Normalised ascent step; the clamp avoids 0/0 on a zero gradient.
                    delta += alpha * grad / norm_l2(grad).clamp(min=1e-12)
                    # Keep cond_x + delta inside [0, 1].
                    # NOTE(review): assumes cond_x is scaled to [0, 1] — confirm
                    # against the data pipeline.
                    delta.copy_(torch.min(torch.max(delta, -cond_x), 1 - cond_x))
                    # Project back onto the epsilon ball of norm_l2.
                    delta *= epsilon / norm_l2(delta).clamp(min=epsilon)
                else:
                    delta.copy_((delta + alpha * grad.sign()).clamp(-epsilon, epsilon))

    delta = delta.detach()
    if torch.isnan(delta).any():
        print("delta is nan")
        delta = torch.zeros_like(delta)
    return delta

In [7]:
@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm = -1, type = 'ct', args = None, model = None, epsilon = 1e-1):
    """Run one epoch of adversarial maximum-likelihood training of ``net``.

    For each batch, a PGD perturbation of the conditioning input is built with
    ``pgd_linf_manual``. Then ``loss_fn`` on ``net(x, cond_x + delta)`` is minimised.

    Args:
        epoch: epoch index (only printed).
        net: conditional flow called as ``net(x, cond, reverse=False)`` -> ``(z, sldj)``.
        trainloader: loader whose batches are indexable as ``batch[0]`` (target)
            and ``batch[1]`` (condition).
        device: device each batch is moved to.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once per batch when not None.
        loss_fn: ``loss_fn(z, sldj)`` -> scalar loss (NLL).
        max_grad_norm: clip gradients to this norm when > 0.
        type: modality string (``args.mode`` at the call site); unused here.
        args: run configuration; only ``args.ext`` is read (``'ll'`` if None).
        model: unused; kept for interface compatibility.
        epsilon: unused; the attack budget comes from ``pgd_linf_manual``'s defaults.

    Returns:
        Average loss over the epoch.

    Raises:
        NotImplementedError: if ``args.ext`` is anything other than ``'ll'``.
    """
    global global_step
    global_step = 0

    ext = 'll' if args is None else args.ext
    if ext != 'll':
        # The 'll+sl_pl' / 'll+sl_l1' variants are not implemented here; fail
        # loudly instead of silently skipping every optimisation step.
        raise NotImplementedError(f"train() only supports ext='ll', got {ext!r}")

    print('\nEpoch: %d' % epoch)
    net.train()
    latent_loss_m = util.AverageMeter()

    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for batch in trainloader:
            x, cond_x = batch[0], batch[1]
            # Insert a channel axis at dim 1 for batches with fewer than 4 dims.
            if x.dim() < 4:
                x = x.unsqueeze(1)
            if cond_x.dim() < 4:
                cond_x = cond_x.unsqueeze(1)
            x, cond_x = x.to(device), cond_x.to(device)

            # Adversarially perturb the condition, then fit the perturbed pair.
            delta = pgd_linf_manual(net, x, cond_x, loss_fn, reverse=False)
            z, sldj = net(x, cond_x + delta, reverse=False)

            optimizer.zero_grad()
            latent_loss = loss_fn(z, sldj)
            latent_loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            global_step += x.size(0)

            latent_loss_m.update(latent_loss.item(), x.size(0))
            progress_bar.set_postfix(nll=latent_loss_m.avg,
                                     bpd=util.bits_per_dim(x, latent_loss_m.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))

    return latent_loss_m.avg


@torch.no_grad()
def test(epoch, net, testloader, device, args, path, history = []):
    global best_ssim
    global best_epoch
    net.eval()

    rrmse_val, psnr_val, ssim_val = evaluate_1c(net, testloader, device, args.mode)
    ssim = np.mean(ssim_val)
    flag = True
    history.append(ssim)

    # Save checkpoint
    if torch.isnan(torch.tensor(ssim)):
        return True

    if flag and ssim > best_ssim:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'ssim': ssim,
            'rrmse': np.mean(rrmse_val),
            'psnr': np.mean(psnr_val),
            'epoch': epoch,
        }
        path1 = '/'.join(path.split('/')[:-1])
        os.makedirs(path1, exist_ok=True)
        torch.save(state, path)        
        best_ssim = ssim
        best_epoch = epoch

    return False


In [None]:
# Negative log-likelihood objective.
# NOTE(review): k=65535 presumably is the number of intensity levels (16-bit) — confirm.
loss_fn = util.NLLLoss(k=65535).to(args.device)

# The optimiser is selected from the run name encoded in args.path.
if 'sgd' in args.path:
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr)
elif 'adam' in args.path:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
else:
    raise ValueError(f"cannot infer optimizer from args.path={args.path!r}; "
                     "expected 'sgd' or 'adam' in it")
scheduler = None  # sched.LambdaLR(optimizer, lambda s: min(1., s / args.warm_up))

# Epochs are 0-based and exactly args.epoch_num of them run in total. Resuming
# from model_R.ckpt (saved after epoch index R-1) continues at epoch R.
for epoch in range(args.resume_training, args.epoch_num):
    train(epoch, net, trainloader, args.device, optimizer, scheduler,
          loss_fn, type=args.mode, args=args, model=unet)
    if (epoch + 1) % 25 == 0 and args.save:
        # Same naming scheme the resume code reads back.
        torch.save(net.state_dict(), args.save_path + '/model_' + str(epoch + 1) + '.ckpt')


Epoch: 0


100%|██████████| 200/200 [05:48<00:00,  1.74s/it, bpd=10.7, lr=0.001, nll=1.95e+6]



Epoch: 1


100%|██████████| 200/200 [05:41<00:00,  1.71s/it, bpd=10.3, lr=0.001, nll=1.88e+6]



Epoch: 2


100%|██████████| 200/200 [05:44<00:00,  1.72s/it, bpd=10.3, lr=0.001, nll=1.87e+6]



Epoch: 3


100%|██████████| 200/200 [05:47<00:00,  1.74s/it, bpd=10.3, lr=0.001, nll=1.87e+6]



Epoch: 4


100%|██████████| 200/200 [05:45<00:00,  1.73s/it, bpd=10.2, lr=0.001, nll=1.86e+6]



Epoch: 5


100%|██████████| 200/200 [05:48<00:00,  1.74s/it, bpd=10.2, lr=0.001, nll=1.85e+6]



Epoch: 6


100%|██████████| 200/200 [05:47<00:00,  1.74s/it, bpd=10.2, lr=0.001, nll=1.85e+6]



Epoch: 7


100%|██████████| 200/200 [05:46<00:00,  1.73s/it, bpd=10.2, lr=0.001, nll=1.85e+6]



Epoch: 8


100%|██████████| 200/200 [05:46<00:00,  1.73s/it, bpd=10.1, lr=0.001, nll=1.84e+6]



Epoch: 9


100%|██████████| 200/200 [05:48<00:00,  1.74s/it, bpd=10.1, lr=0.001, nll=1.84e+6]



Epoch: 10


100%|██████████| 200/200 [05:48<00:00,  1.74s/it, bpd=10.1, lr=0.001, nll=1.84e+6]



Epoch: 11


100%|██████████| 200/200 [05:50<00:00,  1.75s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 12


100%|██████████| 200/200 [05:47<00:00,  1.74s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 13


100%|██████████| 200/200 [05:53<00:00,  1.77s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 14


100%|██████████| 200/200 [06:03<00:00,  1.82s/it, bpd=10.2, lr=0.001, nll=1.85e+6]



Epoch: 15


100%|██████████| 200/200 [06:04<00:00,  1.82s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 16


100%|██████████| 200/200 [05:44<00:00,  1.72s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 17


100%|██████████| 200/200 [05:46<00:00,  1.73s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 18


100%|██████████| 200/200 [05:46<00:00,  1.73s/it, bpd=10.1, lr=0.001, nll=1.83e+6]



Epoch: 19


100%|██████████| 200/200 [05:46<00:00,  1.73s/it, bpd=10, lr=0.001, nll=1.83e+6]  



Epoch: 20


 64%|██████▎   | 127/200 [03:40<02:06,  1.74s/it, bpd=10, lr=0.001, nll=1.82e+6]  

In [None]:
# Evaluate the trained flow on the test split.
errors = test_model(args, net, testloader)
# Median over axis 0 (presumably across test samples, per error metric).
median_errors = np.median(errors, axis=0)
print(median_errors)