In [1]:
import numpy as np
from tqdm import tqdm
from utils import *
from utils_data import *
import argparse
import os
import random
random.seed(0)

from models import Glow
from models.glow.coupling import UNet1
import util
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

In [2]:
def str2bool(s):
    """Convert a command-line string to a bool.

    Needed because argparse's ``type=bool`` calls ``bool(s)`` and every
    non-empty string -- including ``'False'`` -- is truthy. Any value whose
    lower-cased form starts with ``'t'`` (``True``, ``true``, ``t``) is True;
    everything else is False.
    """
    return s.lower().startswith('t')


def _number(s):
    """Parse a numeric CLI value, keeping integral values as int ('5000' -> 5000, '0.5' -> 0.5)."""
    value = float(s)
    return int(value) if value.is_integer() else value


parser = argparse.ArgumentParser()

# --- data / task ---
parser.add_argument('--mode', type=str, default='ct')
# One value per modality; for PET-CT pass two values: --noise_level <PET> <CT>.
# (type=list would have split '5000' into ['5', '0', '0', '0'].)
parser.add_argument('--noise_level', type=_number, nargs='+', default=[5000])
# Boolean flags use str2bool: with type=bool, '--semi_sup False' would parse as True.
parser.add_argument('--semi_sup', type=str2bool, default=True)
parser.add_argument('--supervision', type=float, default=0.0)
parser.add_argument('--secondary_noisy', type=str2bool, default=True)
parser.add_argument('--resume_training', type=int, default=300)
parser.add_argument('--train_size', type=int, default=200)
parser.add_argument('--blur_mode', type=str, default=None)
parser.add_argument('--new_range', type=int, default=2)

# --- transfer learning ---
parser.add_argument('--transfer_learning', type=str2bool, default=False)
parser.add_argument('--transfer_path', type=str, default='../results200_nd/unet_var_multi5/e3sgdws_petct_bpetnoperc_unet_var_ggg_multif_semi_0.0005-5000_0.5/model_400.ckpt')

# --- optimisation ---
parser.add_argument('--n_samples', type=int, default=1)
parser.add_argument('--s_samples', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--epoch_num', type=int, default=500)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--device', type=str, default='cuda:1')
parser.add_argument('--perceptual', type=str2bool, default=False)
# Loss weights in the order (pet, ct, latent); exactly three values on the CLI.
parser.add_argument('--weights', type=float, nargs=3, default=[1, 1, 1])
parser.add_argument('--target', type=str, default=None)

# --- output ---
parser.add_argument('--save', type=str2bool, default=True)
parser.add_argument('--path', type=str, default='../results/e3adam_')
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path_fig', type=str, default='')

# --- Glow architecture ---
parser.add_argument('--num_levels', '-L', default=4, type=int, help='Number of levels in the Glow model')
parser.add_argument('--num_steps', '-K', default=8, type=int, help='Number of steps of flow in each level')
parser.add_argument('--cc', type=str2bool, default=False)
parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
parser.add_argument('--ext', default='ll', type=str)

# Parse an explicit empty list so the notebook kernel's own sys.argv is ignored;
# every value above therefore comes from its default.
args = parser.parse_args(args=[])
# args_check(args)

In [3]:
# Build the data loaders from the parsed configuration; load_data returns a
# (train, test, validation) triple. The validation loader is not used by the
# later cells shown here.
(trainloader,
 testloader,
 validloader) = load_data(args)

(200, 1, 512, 512) (326, 1, 512, 512)
1 secondary taget used


In [4]:
# Resolve the results directory for this run; the returned namespace replaces `args`.
# Presumably this fills in args.save_path, which the checkpoint load/save code reads -- TODO confirm.
args = create_save_path(args)

../results/e3adam_ct_semi_5000_0.0


In [5]:
# Conditional Glow flow. num_channels / inp_channel / cond_channel are all 1,
# presumably a single-channel target image conditioned on a single-channel input.
net = Glow(num_channels=1,
           num_levels=args.num_levels,
           num_steps=args.num_steps,
           inp_channel=1,
           cond_channel=1,
           cc=args.cc)
net = net.to(args.device)
# Let cuDNN benchmark and cache the fastest convolution algorithms
# (pays off when input shapes stay constant between iterations).
cudnn.benchmark = True

if args.resume_training != 0:
    # Resume from the raw state_dict the training loop writes as model_<epoch>.ckpt.
    # Loading onto CPU first avoids depending on the device the checkpoint was saved
    # from; load_state_dict then copies the tensors into net's parameters on args.device.
    resume_ckpt = os.path.join(args.save_path, f'model_{args.resume_training}.ckpt')
    net.load_state_dict(torch.load(resume_ckpt, map_location='cpu'))

# Auxiliary UNet; it is passed to train() as `model`, which train() does not use.
unet = UNet1(inp_channels=1, op_channels=1)
unet = unet.to(args.device)

In [6]:
@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm = -1, type = 'ct', args = None, model = None, epsilon = 1e-1):
    global global_step
    global_step = 0
    print('\nEpoch: %d' % epoch)
    net.train()
    latent_loss_m = util.AverageMeter()
    spatial_loss_m = util.AverageMeter()
    idx1, idx2 = get_idx(type)
    smooth_l1_loss = torch.nn.SmoothL1Loss().to(device)

    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for i, x_prime in enumerate(trainloader):
            x = x_prime[0] #x_prime[:, idx1, :, :]
            cond_x = x_prime[1] #x_prime[:, idx2, :, :]
#             mask = x_prime[:, 4, :, :].unsqueeze(1).to(device)
#             mask = torch.where(mask > 0, 1, 0)
            if len(x.shape) < 4:
                x = x.unsqueeze(1)
            if len(cond_x.shape) < 4:
                cond_x = cond_x.unsqueeze(1)
            x, cond_x = x.to(device), cond_x.to(device)
#             cond_x.requires_grad = True
            z, sldj = net(x, cond_x, reverse=False)

            if args.ext == 'll+sl_pl' or args.ext == 'll+sl_l1':
                ;
            elif args.ext == 'll':
#                 print(optimizer.param_groups[0]['lr'])
                optimizer.zero_grad()
                latent_loss = loss_fn(z, sldj)
                latent_loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
#                 scheduler.step(global_step)

                latent_loss_m.update(latent_loss.item(), x.size(0))
                progress_bar.set_postfix(nll=latent_loss_m.avg,
                                        bpd=util.bits_per_dim(x, latent_loss_m.avg),
                                        lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(x.size(0))
#                 global_step += x.size(0)


@torch.no_grad()
def test(epoch, net, testloader, device, args, path, history = []):
    global best_ssim
    global best_epoch
    net.eval()

    rrmse_val, psnr_val, ssim_val = evaluate_1c(net, testloader, device, args.mode)
    ssim = np.mean(ssim_val)
    flag = True
    history.append(ssim)

    # Save checkpoint
    if torch.isnan(torch.tensor(ssim)):
        return True

    if flag and ssim > best_ssim:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'ssim': ssim,
            'rrmse': np.mean(rrmse_val),
            'psnr': np.mean(psnr_val),
            'epoch': epoch,
        }
        path1 = '/'.join(path.split('/')[:-1])
        os.makedirs(path1, exist_ok=True)
        torch.save(state, path)        
        best_ssim = ssim
        best_epoch = epoch

    return False

In [None]:
# Negative log-likelihood objective for the flow. k=65535 presumably matches
# 16-bit intensity quantisation of the inputs -- TODO confirm against util.NLLLoss.
loss_fn = util.NLLLoss(k=65535).to(args.device)

# The optimizer is inferred from the results-path prefix (e.g. '../results/e3adam_').
if 'sgd' in args.path:
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr)
elif 'adam' in args.path:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
else:
    # Without this branch `optimizer` would be undefined (or stale from an earlier
    # run of this cell) and training would fail or silently reuse the wrong optimizer.
    raise ValueError(
        f"cannot infer optimizer from args.path={args.path!r}; "
        "expected it to contain 'sgd' or 'adam'")

# LR warm-up via sched.LambdaLR(optimizer, lambda s: min(1., s / args.warm_up)) is
# disabled; train() does not step the scheduler.
scheduler = None

CHECKPOINT_EVERY = 25  # epochs between state_dict checkpoints

epoch = args.resume_training
history = []  # not populated by this cell
while epoch < args.epoch_num:
    train(epoch, net, trainloader, args.device, optimizer, scheduler,
          loss_fn, type=args.mode, args=args, model=unet)
    if (epoch + 1) % CHECKPOINT_EVERY == 0 and args.save:
        # Raw state_dict named by the number of completed epochs; the model-setup
        # cell reloads it when --resume_training equals that number.
        torch.save(net.state_dict(),
                   os.path.join(args.save_path, f'model_{epoch + 1}.ckpt'))
    epoch += 1


Epoch: 300


100%|██████████| 200/200 [00:56<00:00,  3.57it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 301


100%|██████████| 200/200 [00:54<00:00,  3.64it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 302


100%|██████████| 200/200 [00:54<00:00,  3.64it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 303


100%|██████████| 200/200 [00:57<00:00,  3.46it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 304


100%|██████████| 200/200 [00:53<00:00,  3.75it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 305


100%|██████████| 200/200 [00:53<00:00,  3.75it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 306


100%|██████████| 200/200 [00:54<00:00,  3.67it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 307


100%|██████████| 200/200 [00:54<00:00,  3.70it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 308


100%|██████████| 200/200 [00:54<00:00,  3.69it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 309


100%|██████████| 200/200 [00:55<00:00,  3.62it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 310


100%|██████████| 200/200 [00:53<00:00,  3.72it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 311


100%|██████████| 200/200 [00:54<00:00,  3.67it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 312


100%|██████████| 200/200 [00:54<00:00,  3.64it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 313


100%|██████████| 200/200 [00:54<00:00,  3.65it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 314


100%|██████████| 200/200 [00:54<00:00,  3.67it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 315


100%|██████████| 200/200 [00:54<00:00,  3.69it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 316


100%|██████████| 200/200 [00:54<00:00,  3.67it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 317


100%|██████████| 200/200 [00:54<00:00,  3.66it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 318


100%|██████████| 200/200 [00:55<00:00,  3.63it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 319


100%|██████████| 200/200 [00:57<00:00,  3.51it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 320


100%|██████████| 200/200 [00:58<00:00,  3.43it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 321


100%|██████████| 200/200 [00:57<00:00,  3.48it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 322


100%|██████████| 200/200 [00:59<00:00,  3.35it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 323


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 324


100%|██████████| 200/200 [01:04<00:00,  3.11it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 325


100%|██████████| 200/200 [01:05<00:00,  3.07it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 326


100%|██████████| 200/200 [01:08<00:00,  2.91it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 327


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 328


100%|██████████| 200/200 [01:06<00:00,  3.02it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 329


100%|██████████| 200/200 [01:09<00:00,  2.86it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 330


100%|██████████| 200/200 [01:08<00:00,  2.94it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 331


100%|██████████| 200/200 [01:06<00:00,  3.03it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 332


100%|██████████| 200/200 [01:07<00:00,  2.95it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 333


100%|██████████| 200/200 [01:07<00:00,  2.97it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 334


100%|██████████| 200/200 [01:08<00:00,  2.92it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 335


100%|██████████| 200/200 [01:09<00:00,  2.87it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 336


100%|██████████| 200/200 [01:08<00:00,  2.94it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 337


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 338


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 339


100%|██████████| 200/200 [01:06<00:00,  3.01it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 340


100%|██████████| 200/200 [01:05<00:00,  3.03it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 341


100%|██████████| 200/200 [01:04<00:00,  3.08it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 342


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 343


100%|██████████| 200/200 [01:05<00:00,  3.07it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 344


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 345


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 346


100%|██████████| 200/200 [01:06<00:00,  3.02it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 347


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 348


100%|██████████| 200/200 [01:05<00:00,  3.03it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 349


100%|██████████| 200/200 [01:05<00:00,  3.06it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 350


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 351


100%|██████████| 200/200 [01:05<00:00,  3.07it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 352


100%|██████████| 200/200 [01:04<00:00,  3.09it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 353


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 354


100%|██████████| 200/200 [01:06<00:00,  2.99it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 355


100%|██████████| 200/200 [01:06<00:00,  3.02it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 356


100%|██████████| 200/200 [01:05<00:00,  3.06it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 357


100%|██████████| 200/200 [01:06<00:00,  3.01it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 358


100%|██████████| 200/200 [01:06<00:00,  3.01it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 359


100%|██████████| 200/200 [01:06<00:00,  3.00it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 360


100%|██████████| 200/200 [01:08<00:00,  2.92it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 361


100%|██████████| 200/200 [01:08<00:00,  2.92it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 362


100%|██████████| 200/200 [01:05<00:00,  3.05it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 363


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 364


100%|██████████| 200/200 [01:07<00:00,  2.95it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 365


100%|██████████| 200/200 [01:09<00:00,  2.86it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 366


100%|██████████| 200/200 [01:09<00:00,  2.88it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 367


100%|██████████| 200/200 [01:10<00:00,  2.85it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 368


100%|██████████| 200/200 [01:07<00:00,  2.96it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 369


100%|██████████| 200/200 [01:07<00:00,  2.96it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 370


100%|██████████| 200/200 [01:08<00:00,  2.92it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 371


100%|██████████| 200/200 [01:09<00:00,  2.89it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 372


100%|██████████| 200/200 [01:11<00:00,  2.80it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 373


100%|██████████| 200/200 [01:09<00:00,  2.90it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 374


100%|██████████| 200/200 [01:09<00:00,  2.89it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 375


100%|██████████| 200/200 [01:10<00:00,  2.84it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 376


100%|██████████| 200/200 [01:10<00:00,  2.85it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 377


100%|██████████| 200/200 [01:09<00:00,  2.88it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 378


100%|██████████| 200/200 [01:09<00:00,  2.89it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 379


100%|██████████| 200/200 [01:09<00:00,  2.87it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 380


100%|██████████| 200/200 [01:08<00:00,  2.91it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 381


100%|██████████| 200/200 [01:11<00:00,  2.80it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 382


100%|██████████| 200/200 [01:09<00:00,  2.86it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 383


100%|██████████| 200/200 [01:11<00:00,  2.79it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 384


100%|██████████| 200/200 [01:10<00:00,  2.82it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 385


100%|██████████| 200/200 [01:11<00:00,  2.78it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 386


100%|██████████| 200/200 [01:09<00:00,  2.86it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 387


100%|██████████| 200/200 [01:09<00:00,  2.88it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 388


100%|██████████| 200/200 [01:09<00:00,  2.87it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 389


100%|██████████| 200/200 [01:09<00:00,  2.88it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 390


100%|██████████| 200/200 [01:07<00:00,  2.95it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 391


100%|██████████| 200/200 [01:10<00:00,  2.86it/s, bpd=9.9, lr=0.001, nll=1.8e+6] 



Epoch: 392


100%|██████████| 200/200 [01:08<00:00,  2.92it/s, bpd=9.89, lr=0.001, nll=1.8e+6]



Epoch: 393


100%|██████████| 200/200 [01:07<00:00,  2.96it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 394


100%|██████████| 200/200 [01:08<00:00,  2.91it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 395


100%|██████████| 200/200 [01:05<00:00,  3.07it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 396


100%|██████████| 200/200 [01:06<00:00,  3.02it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 397


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 398


100%|██████████| 200/200 [01:07<00:00,  2.94it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 399


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 400


100%|██████████| 200/200 [01:07<00:00,  2.97it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 401


100%|██████████| 200/200 [01:06<00:00,  2.99it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 402


100%|██████████| 200/200 [01:04<00:00,  3.10it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 403


100%|██████████| 200/200 [01:08<00:00,  2.94it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 404


100%|██████████| 200/200 [01:05<00:00,  3.03it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 405


100%|██████████| 200/200 [01:07<00:00,  2.95it/s, bpd=9.9, lr=0.001, nll=1.8e+6]  



Epoch: 406


100%|██████████| 200/200 [01:11<00:00,  2.81it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 407


100%|██████████| 200/200 [01:07<00:00,  2.97it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 408


100%|██████████| 200/200 [01:06<00:00,  3.00it/s, bpd=9.89, lr=0.001, nll=1.8e+6]



Epoch: 409


100%|██████████| 200/200 [01:06<00:00,  3.01it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 410


100%|██████████| 200/200 [01:05<00:00,  3.04it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 411


100%|██████████| 200/200 [01:06<00:00,  2.99it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 412


100%|██████████| 200/200 [01:07<00:00,  2.98it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 413


100%|██████████| 200/200 [01:04<00:00,  3.09it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 414


100%|██████████| 200/200 [01:07<00:00,  2.96it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 415


100%|██████████| 200/200 [01:06<00:00,  3.02it/s, bpd=9.89, lr=0.001, nll=1.8e+6]



Epoch: 416


100%|██████████| 200/200 [01:05<00:00,  3.03it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 417


100%|██████████| 200/200 [01:07<00:00,  2.98it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 418


100%|██████████| 200/200 [01:05<00:00,  3.06it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 419


100%|██████████| 200/200 [01:05<00:00,  3.03it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 420


100%|██████████| 200/200 [01:06<00:00,  3.02it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 421


100%|██████████| 200/200 [01:06<00:00,  3.00it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 422


100%|██████████| 200/200 [01:07<00:00,  2.98it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 423


100%|██████████| 200/200 [01:06<00:00,  3.00it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 424


100%|██████████| 200/200 [01:10<00:00,  2.84it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 425


100%|██████████| 200/200 [01:10<00:00,  2.84it/s, bpd=9.89, lr=0.001, nll=1.8e+6]



Epoch: 426


100%|██████████| 200/200 [01:09<00:00,  2.87it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 427


100%|██████████| 200/200 [01:10<00:00,  2.83it/s, bpd=9.89, lr=0.001, nll=1.8e+6] 



Epoch: 428


 68%|██████▊   | 136/200 [00:48<00:21,  3.04it/s, bpd=9.91, lr=0.001, nll=1.8e+6]

In [None]:
# Evaluate the trained flow on the test split and report the median of the
# returned errors along axis 0 (presumably one row per test case and one
# column per metric -- TODO confirm against test_model).
errors = test_model(args, net, testloader)
median_errors = np.median(errors, axis=0)
print(median_errors)