In [1]:
import numpy as np
from tqdm import tqdm
from utils import *
from utils_data import *
import argparse
import os
import random
random.seed(0)

from models import Glow
from models.glow.coupling import UNet1
import util
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

In [2]:
parser = argparse.ArgumentParser()

parser.add_argument('--mode', type=str, default='ct')
parser.add_argument('--noise_level', type=list, default=[5000]) # For PET-CT, noise_level = [PET, CT]
parser.add_argument('--semi_sup', type=bool, default=True)
parser.add_argument('--supervision', type=float, default=0.0)
parser.add_argument('--secondary_noisy', type=int, default=0)
parser.add_argument('--resume_training', type=int, default=0)
parser.add_argument('--train_size', type=int, default=200)
parser.add_argument('--blur_mode',type=str, default=None)
parser.add_argument('--new_range',type=int, default=2)

parser.add_argument('--transfer_learning', type=bool, default=False)
parser.add_argument('--transfer_path', type=str, default='../results200_nd/unet_var_multi5/e3sgdws_petct_bpetnoperc_unet_var_ggg_multif_semi_0.0005-5000_0.5/model_400.ckpt')

parser.add_argument('--n_samples', type=int, default=1)
parser.add_argument('--s_samples', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--epoch_num', type=int, default=3000)
parser.add_argument('--lr', type=float, default=1e-2)
parser.add_argument('--device', type=str, default='cuda:1')
parser.add_argument('--perceptual', type=bool, default=False)
parser.add_argument('--weights',type=tuple,default=[1, 1, 1]) #(pet, ct, latent)
parser.add_argument('--target',type=str,default=None)

parser.add_argument('--save', type=bool, default=True)
parser.add_argument('--path', type=str, default='../results/e2sgd_')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path_fig', type=str, default='')

def str2bool(s):
    return s.lower().startswith('t')
parser.add_argument('--num_levels', '-L', default=4, type=int, help='Number of levels in the Glow model')
parser.add_argument('--num_steps', '-K', default=8, type=int, help='Number of steps of flow in each level')
parser.add_argument('--cc', type = str2bool, default = False)
parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
parser.add_argument('--ext', default = 'll', type=str)

args = parser.parse_args(args=[])
# args_check(args)

In [3]:
trainloader, testloader, validloader = load_data(args)

(200, 1, 512, 512) (326, 1, 512, 512)
same used
supervised
supervised


In [4]:
net = Glow(num_channels=1,
               num_levels=args.num_levels,
               num_steps=args.num_steps,
               inp_channel=1,
               cond_channel=1,
               cc = args.cc)
net = net.to(args.device)
# cudnn.benchmark = True
# if device == 'cuda':
#     net = torch.nn.DataParallel(net, args.gpu_ids)
#     cudnn.benchmark = args.benchmark

unet = UNet1(inp_channels=1, op_channels=1)
unet = unet.to(args.device)
# unet_weights = torch.load('ckpts/unet/best.pth', map_location = args.device)
# unet.load_state_dict(unet_weights)

In [5]:
@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm = -1, type = 'ct', args = None, model = None, epsilon = 1e-1):
    global global_step
    global_step = 0
    print('\nEpoch: %d' % epoch)
    net.train()
    latent_loss_m = util.AverageMeter()
    spatial_loss_m = util.AverageMeter()
    idx1, idx2 = get_idx(type)
    smooth_l1_loss = torch.nn.SmoothL1Loss().to(device)

    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for i, x_prime in enumerate(trainloader):
            x = x_prime[1] #x_prime[:, idx1, :, :]
            cond_x = x_prime[0] #x_prime[:, idx2, :, :]
#             mask = x_prime[:, 4, :, :].unsqueeze(1).to(device)
#             mask = torch.where(mask > 0, 1, 0)
            if len(x.shape) < 4:
                x = x.unsqueeze(1)
            if len(cond_x.shape) < 4:
                cond_x = cond_x.unsqueeze(1)
            x, cond_x = x.to(device), cond_x.to(device)
            cond_x.requires_grad = True
            z, sldj = net(x, cond_x, reverse=False)

            if args.ext == 'll+sl_pl' or args.ext == 'll+sl_l1':
                sl_name = args.ext.split('_')[1]
                optimizer.zero_grad()
                latent_loss = loss_fn(z, sldj)
                new_z = torch.randn(x.shape, dtype=torch.float32, device=device) * 0.6
                rec_x, sldj = net(new_z, cond_x, reverse=True)
                rec_x = torch.sigmoid(rec_x) #mask * torch.sigmoid(rec_x)
                if sl_name == 'pl':
                    spatial_loss = perceptual_loss(rec_x, x, model, smooth_l1_loss)
                elif sl_name == 'l1':
                    spatial_loss = smooth_l1_loss(rec_x, x)

                loss = latent_loss + spatial_loss
                loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)

                latent_loss_m.update(latent_loss.item(), x.size(0))
                spatial_loss_m.update(spatial_loss.item(), x.size(0))
                progress_bar.set_postfix(bpd=util.bits_per_dim(x, latent_loss_m.avg),
                                         pl = spatial_loss_m.avg)
                progress_bar.update(x.size(0))
                global_step += x.size(0)

                # Adversarial Examples Training

                # Calculating FGSM
                cond_x = cond_x + (epsilon * torch.sign(cond_x.grad))
                cond_x = torch.clamp(cond_x, 0, 1)

                # Feeding to the network and calculating loss
                z, sldj = net(x, cond_x, reverse=False)
                optimizer.zero_grad()
                latent_loss = loss_fn(z, sldj)
                new_z = torch.randn(x.shape, dtype=torch.float32, device=device) * 0.6
                rec_x, sldj = net(new_z, cond_x, reverse=True)
                rec_x = torch.sigmoid(rec_x) #mask * torch.sigmoid(rec_x)
                if sl_name == 'pl':
                    spatial_loss = perceptual_loss(rec_x, x, model, smooth_l1_loss)
                elif sl_name == 'l1':
                    spatial_loss = smooth_l1_loss(rec_x, x)
                loss = latent_loss + spatial_loss

                # Backpropagation
                loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)

                # Updating the progress bar
                latent_loss_m.update(latent_loss.item(), x.size(0))
                spatial_loss_m.update(spatial_loss.item(), x.size(0))
                progress_bar.set_postfix(bpd=util.bits_per_dim(x, latent_loss_m.avg),
                                         pl = spatial_loss_m.avg)
                progress_bar.update(x.size(0))
                global_step += x.size(0)

            # elif args.ext == 'll_then_sl':
            #     optimizer.zero_grad()
            #     latent_loss = loss_fn(z, sldj)
            #     latent_loss.backward()
            #     if max_grad_norm > 0:
            #         util.clip_grad_norm(optimizer, max_grad_norm)
            #     optimizer.step()
            #     scheduler.step(global_step)

            #     new_z = torch.randn(x.shape, dtype=torch.float32, device=device) * 0.6
            #     rec_x, sldj = net(new_z, cond_x, reverse=True)
            #     rec_x = mask * torch.sigmoid(rec_x)
            #     optimizer.zero_grad()
            #     # l1_loss = smooth_l1_loss(rec_x, x)
            #     # ssim_loss = 1 - ssim1(rec_x, x, data_range = 1)
            #     # spatial_loss = l1_loss + ssim_loss
            #     spatial_loss = perceptual_loss(rec_x, x, model, smooth_l1_loss)
            #     spatial_loss.backward()
            #     if max_grad_norm > 0:
            #         util.clip_grad_norm(optimizer, max_grad_norm)
            #     optimizer.step()
            #     scheduler.step(global_step)
            #     print(f'll:{latent_loss}, sl:{spatial_loss}')
            #     latent_loss_m.update(latent_loss.item(), x.size(0))
            #     # l1_loss_m.update(l1_loss.item(), x.size(0))
            #     # ssim_loss_m.update(ssim_loss.item(), x.size(0))
            #     perceptual_loss_m.update(spatial_loss.item(), x.size(0))
            #     progress_bar.set_postfix(bpd=util.bits_per_dim(x, latent_loss_m.avg),
            #                              pl=perceptual_loss_m.avg)
            #                             # ssim=ssim_loss_m.avg,
            #                             # l1=l1_loss_m.avg)
            #     progress_bar.update(x.size(0))
            #     global_step += x.size(0)

            elif args.ext == 'll':
                optimizer.zero_grad()
                latent_loss = loss_fn(z, sldj)
                latent_loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)

                latent_loss_m.update(latent_loss.item(), x.size(0))
                progress_bar.set_postfix(nll=latent_loss_m.avg,
                                        bpd=util.bits_per_dim(x, latent_loss_m.avg),
                                        lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(x.size(0))
                global_step += x.size(0)

                # Adversarial Examples Training
                # fig, ax = plt.subplots(1, 3, figsize=(30, 30))
                # ax[0].imshow(cond_x[0,0,:,:].detach().cpu(), cmap='gray')
                # im = ax[1].imshow((epsilon * cond_x.grad)[0,0,:,:].detach().cpu(), cmap='gray')
                cond_x = cond_x + (epsilon * torch.sign(cond_x.grad))
                # ax[2].imshow(cond_x[0,0,:,:].detach().cpu(), cmap='gray')
                # ax[0].axis('off')
                # ax[1].axis('off')
                # ax[2].axis('off')
                # cbar = plt.colorbar(im, ax=ax, orientation = 'horizontal', 
                #  pad = 0.01, aspect = 100)
                # plt.savefig(f'test.png', bbox_inches='tight')
                cond_x = torch.clamp(cond_x, 0, 1)
                z, sldj = net(x, cond_x, reverse=False)
                optimizer.zero_grad()
                latent_loss = loss_fn(z, sldj)
                latent_loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)
                
        
                latent_loss_m.update(latent_loss.item(), x.size(0))
                progress_bar.set_postfix(nll=latent_loss_m.avg,
                                        bpd=util.bits_per_dim(x, latent_loss_m.avg),
                                        lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(x.size(0))
                global_step += x.size(0)



@torch.no_grad()
def test(epoch, net, testloader, device, args, path, history = []):
    global best_ssim
    global best_epoch
    net.eval()

    rrmse_val, psnr_val, ssim_val = evaluate_1c(net, testloader, device, args.type)
    ssim = np.mean(ssim_val)
    flag = True
    history.append(ssim)

    # Save checkpoint
    if torch.isnan(torch.tensor(ssim)):
        return True

    if flag and ssim > best_ssim:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'ssim': ssim,
            'rrmse': np.mean(rrmse_val),
            'psnr': np.mean(psnr_val),
            'epoch': epoch,
        }
        path1 = '/'.join(path.split('/')[:-1])
        os.makedirs(path1, exist_ok=True)
        torch.save(state, path)        
        best_ssim = ssim
        best_epoch = epoch

    return False

In [6]:
loss_fn = util.NLLLoss().to(args.device)
# loss_fn = util.NLLLoss(shape = args.shape, device = device).to(device)
optimizer = optim.Adam(net.parameters(), lr=args.lr)
scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / args.warm_up))
start_epoch = 0

epoch = start_epoch
c = 0
history = []
while epoch <= args.epoch_num:
    c += 1
    train(epoch, net, trainloader, args.device, optimizer, scheduler,
          loss_fn, type = args.mode, args = args, model=unet)
    if test(epoch, net, testloader, device, args, path, history):
        if os.path.exists(path):  
            checkpoint = torch.load(path, map_location = device)
            net.load_state_dict(checkpoint['net'])
            best_ssim = checkpoint['ssim']
            best_epoch = checkpoint['epoch']
            epoch = best_epoch
            print('Loaded previous model...')
        else:
            net = Glow(num_channels=args.num_channels,
                        num_levels=args.num_levels,
                        num_steps=args.num_steps,
                        mode=args.mode,
                        inp_channel=args.inp_channel,
                        cond_channel=args.cond_channel,
                        cc = args.cc)
            net = net.to(device)
            if device == 'cuda':
                net = torch.nn.DataParallel(net, args.gpu_ids)
                cudnn.benchmark = args.benchmark
            best_ssim = 0
            best_epoch = start_epoch
            epoch = start_epoch
            print('Initialized new model!')

        optimizer = optim.Adam(net.parameters(), lr=args.lr)
        scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / args.warm_up))
    else:
        epoch += 1



Epoch: 0


  4%|▎         | 7/200 [00:09<04:08,  1.29s/it, bpd=nan, lr=1.2e-7, nll=nan]   


RuntimeError: svd_cuda: For batch 0: U(257,257) is zero, singular U.