In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from model import *
from loss import *
from sync_batchnorm import convert_model

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 6
BATCH_SIZE = 3

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


* class down_unit
    * pass
* class up_unit
    * pass
* class input_unit
    * pass
* class output_unit
    * pass

In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
# loaders come with auto batch division and multi-thread acceleration

In [4]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

def init_weights(m):
    if isinstance(m, nn.Conv3d):
        init.kaiming_normal_(m.weight.data)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=0, std=1)

def downsample_label(label, scale_factor):
    return F.interpolate(label, scale_factor=scale_factor, mode='trilinear', align_corners=True)

In [5]:
test_dictionary = train_dataset[33]

image_4 = test_dictionary['image4_data'].view(1, 1, 64, 64, 64)
label_4 = test_dictionary['image4_label'].view(1, 3, 64, 64, 64)

image_2 = test_dictionary['image2_data'].view(1, 1, 128, 128, 128)
label_2 = test_dictionary['image2_label'].view(1, 3, 128, 128, 128)

image_1 = test_dictionary['image1_data'].view(1, 1, 256, 256, 256)
label_1 = test_dictionary['image1_label'].view(1, 3, 256, 256, 256)

label_1_resize_2 = downsample_label(label_1, 1/2) 
label_2_resize_2 = downsample_label(label_2, 1/2)
label_4_resize_2 = downsample_label(label_4, 1/2)

label_1_resize_4 = downsample_label(label_1, 1/4) 
label_2_resize_4 = downsample_label(label_2, 1/4)
label_4_resize_4 = downsample_label(label_4, 1/4)

image_4 = image_4.to(device=device, dtype=dtype)  # move to device, fix dtype
label_4 = label_4.to(device=device, dtype=dtype)
label_4_resize_4 = label_4_resize_4.to(device=device, dtype=dtype)
label_4_resize_2 = label_4_resize_2.to(device=device, dtype=dtype)

image_2 = image_2.to(device=device, dtype=dtype)
label_2 = label_2.to(device=device, dtype=dtype)
label_2_resize_4 = label_2_resize_4.to(device=device, dtype=dtype)
label_2_resize_2 = label_2_resize_2.to(device=device, dtype=dtype)

image_1 = image_1.to(device=device, dtype=dtype) 
label_1 = label_1.to(device=device, dtype=dtype)
label_1_resize_4 = label_1_resize_4.to(device=device, dtype=dtype)
label_1_resize_2 = label_1_resize_2.to(device=device, dtype=dtype)

print("Label 1:", label_1.shape, "Downsampled 1/4:", label_1_resize_4.shape, "Downsampled 1/2:", label_1_resize_2.shape)
print("Label 2:", label_2.shape, "Downsampled 1/4:", label_2_resize_4.shape, "Downsampled 1/2:", label_2_resize_2.shape)
print("Label 4:", label_4.shape, "Downsampled 1/4:", label_4_resize_4.shape, "Downsampled 1/2:", label_4_resize_2.shape)

from model import *

icnet1 = ModifiedICNet(num_classes=3)
icnet1.apply(init_weights)
icnet1 = icnet1.to(device=device, dtype=dtype)

icnet2 = OriginalICNet(num_classes=3)
icnet2.apply(init_weights)
icnet2 = icnet2.to(device=device, dtype=dtype)

full_res_icnet = FullResolutionICNet(num_classes=3)
full_res_icnet = nn.DataParallel(full_res_icnet)
full_res_icnet = convert_model(full_res_icnet)
full_res_icnet.apply(init_weights)
full_res_icnet = full_res_icnet.to(device=device, dtype=dtype)

import torch.optim as optim

optimizer1 = optim.Adam(icnet1.parameters(), lr=1e-2)
optimizer2 = optim.Adam(icnet2.parameters(), lr=1e-2)
optimizer3 = optim.Adam(full_res_icnet.parameters(), lr=1e-2)

Label 1: torch.Size([1, 3, 256, 256, 256]) Downsampled 1/4: torch.Size([1, 3, 64, 64, 64]) Downsampled 1/2: torch.Size([1, 3, 128, 128, 128])
Label 2: torch.Size([1, 3, 128, 128, 128]) Downsampled 1/4: torch.Size([1, 3, 32, 32, 32]) Downsampled 1/2: torch.Size([1, 3, 64, 64, 64])
Label 4: torch.Size([1, 3, 64, 64, 64]) Downsampled 1/4: torch.Size([1, 3, 16, 16, 16]) Downsampled 1/2: torch.Size([1, 3, 32, 32, 32])


In [6]:
from model import *

def shape_test(model, cuda_bool):
    x = torch.zeros((1, 1, 256, 256, 256))
    x = x.to(device=device, dtype=dtype) if cuda_bool else x
    scores = model(x)
    for i in scores:
        print(i.size())

m = FullResolutionICNet(num_classes=3)
m = nn.DataParallel(m)
m = convert_model(m)
m = m.to(device=device, dtype=dtype)
shape_test(m, True)

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
torch.Size([1, 3, 256, 256, 256])
torch.Size([1, 3, 128, 128, 128])
torch.Size([1, 3, 64, 64, 64])


* network
    * test with GPU

In [7]:
# overfit model on single embryo image (modified ICNet Model)
# upsample final outputs by a factor of 4 instead of factor 2
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

record = open('over_fit_full_res_model_check2.txt','w+')

for e in tqdm(range(epochs)):
    
    out_1, out_2, out_4 = full_res_icnet(image_1)
        
    #loss_4 = dice_loss_3(out_4, label_4_resize_2)
    #loss_2 = dice_loss_3(out_2, label_2_resize_2)
    loss_4 = dice_loss_3(out_4, label_4)
    loss_2 = dice_loss_3(out_2, label_2)
    loss_1 = dice_loss_3(out_1, label_1)
    
    #loss = loss_4 + loss_2 + loss_1 
    loss = loss_1 + loss_2 + loss_4
    
    outstr = 'in epoch {}, loss = {}, loss_1: {}, loss_2: {}'.format(e, loss.item(), loss_1.item(), loss_2.item()) + '\n'
    
    print(outstr) 
    record.write(outstr)
    record.flush()
    
    optimizer3.zero_grad()
    loss.backward()
    optimizer3.step()

record.close()

  0%|          | 0/5000 [00:00<?, ?it/s]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 0, loss = 2.30322265625, loss_1: 0.7249848246574402, loss_2: 0.7953076362609863



  0%|          | 1/5000 [01:38<136:25:54, 98.25s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 1, loss = 2.22186279296875, loss_1: 0.59965580701828, loss_2: 0.8397981524467468



  0%|          | 2/5000 [01:42<97:04:33, 69.92s/it] 

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 2, loss = 2.0157437324523926, loss_1: 0.5156428217887878, loss_2: 0.7972216606140137



  0%|          | 3/5000 [01:45<69:31:12, 50.08s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 3, loss = 1.8818776607513428, loss_1: 0.4659937620162964, loss_2: 0.7419031858444214



  0%|          | 4/5000 [01:49<50:14:22, 36.20s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 4, loss = 1.7090258598327637, loss_1: 0.43107056617736816, loss_2: 0.6668326258659363



  0%|          | 5/5000 [01:53<36:44:33, 26.48s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 5, loss = 1.5841742753982544, loss_1: 0.39967817068099976, loss_2: 0.6029988527297974



  0%|          | 6/5000 [01:57<27:17:47, 19.68s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 6, loss = 1.4926862716674805, loss_1: 0.37950628995895386, loss_2: 0.5707460641860962



  0%|          | 7/5000 [02:01<20:41:13, 14.92s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 7, loss = 1.3893029689788818, loss_1: 0.3402247130870819, loss_2: 0.5225376486778259



  0%|          | 8/5000 [02:04<16:03:47, 11.58s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 8, loss = 1.2829465866088867, loss_1: 0.2916432321071625, loss_2: 0.49229249358177185



  0%|          | 9/5000 [02:08<12:49:21,  9.25s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 9, loss = 1.2077640295028687, loss_1: 0.25596678256988525, loss_2: 0.4691804051399231



  0%|          | 10/5000 [02:12<10:33:27,  7.62s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 10, loss = 1.1703203916549683, loss_1: 0.24922893941402435, loss_2: 0.4547489583492279



  0%|          | 11/5000 [02:16<8:58:09,  6.47s/it] 

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 11, loss = 1.18222975730896, loss_1: 0.275799036026001, loss_2: 0.4504990577697754



  0%|          | 12/5000 [02:20<7:51:40,  5.67s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 12, loss = 1.2067351341247559, loss_1: 0.25174659490585327, loss_2: 0.44866085052490234



  0%|          | 13/5000 [02:23<7:04:52,  5.11s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 13, loss = 1.2162026166915894, loss_1: 0.27545973658561707, loss_2: 0.4439317584037781



  0%|          | 14/5000 [02:27<6:32:19,  4.72s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 14, loss = 1.0758006572723389, loss_1: 0.19547638297080994, loss_2: 0.4305034279823303



  0%|          | 15/5000 [02:31<6:09:26,  4.45s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 15, loss = 1.1287128925323486, loss_1: 0.23777352273464203, loss_2: 0.44032782316207886



  0%|          | 16/5000 [02:35<5:53:26,  4.25s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 16, loss = 1.0790328979492188, loss_1: 0.2142048180103302, loss_2: 0.4240497946739197



  0%|          | 17/5000 [02:39<5:42:06,  4.12s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 17, loss = 1.0593526363372803, loss_1: 0.2002461850643158, loss_2: 0.42319056391716003



  0%|          | 18/5000 [02:42<5:34:10,  4.02s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 18, loss = 1.0607211589813232, loss_1: 0.2055160105228424, loss_2: 0.41932451725006104



  0%|          | 19/5000 [02:46<5:28:36,  3.96s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 19, loss = 1.0070762634277344, loss_1: 0.15747910737991333, loss_2: 0.4155387878417969



  0%|          | 20/5000 [02:50<5:24:43,  3.91s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 20, loss = 1.0057953596115112, loss_1: 0.16572800278663635, loss_2: 0.4093037545681



  0%|          | 21/5000 [02:54<5:21:50,  3.88s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 21, loss = 0.9424102306365967, loss_1: 0.11523537337779999, loss_2: 0.4016411304473877



  0%|          | 22/5000 [02:58<5:20:00,  3.86s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 22, loss = 0.9450538158416748, loss_1: 0.1293540596961975, loss_2: 0.396735280752182



  0%|          | 23/5000 [03:01<5:18:41,  3.84s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 23, loss = 0.9192048907279968, loss_1: 0.11326058954000473, loss_2: 0.39393743872642517



  0%|          | 24/5000 [03:05<5:17:39,  3.83s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 24, loss = 0.9109774231910706, loss_1: 0.10403329133987427, loss_2: 0.39534497261047363



  0%|          | 25/5000 [03:09<5:16:50,  3.82s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 25, loss = 0.8995977640151978, loss_1: 0.09992873668670654, loss_2: 0.39267730712890625



  1%|          | 26/5000 [03:13<5:16:25,  3.82s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 26, loss = 0.8793225288391113, loss_1: 0.08772393316030502, loss_2: 0.3914278447628021



  1%|          | 27/5000 [03:17<5:16:03,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 27, loss = 0.8741148114204407, loss_1: 0.08749119937419891, loss_2: 0.38895031809806824



  1%|          | 28/5000 [03:20<5:15:45,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 28, loss = 0.8664515614509583, loss_1: 0.08389656245708466, loss_2: 0.3880535066127777



  1%|          | 29/5000 [03:24<5:15:26,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 29, loss = 0.8540283441543579, loss_1: 0.0772559866309166, loss_2: 0.3855873942375183



  1%|          | 30/5000 [03:28<5:15:20,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 30, loss = 0.8481205701828003, loss_1: 0.07677119970321655, loss_2: 0.3815661072731018



  1%|          | 31/5000 [03:32<5:14:58,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 31, loss = 0.8402727246284485, loss_1: 0.07400014251470566, loss_2: 0.3788529336452484



  1%|          | 32/5000 [03:36<5:14:58,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 32, loss = 0.831584095954895, loss_1: 0.07084157317876816, loss_2: 0.37621045112609863



  1%|          | 33/5000 [03:40<5:15:46,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 33, loss = 0.8269290924072266, loss_1: 0.0698544979095459, loss_2: 0.3732141852378845



  1%|          | 34/5000 [03:43<5:15:25,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 34, loss = 0.8243364095687866, loss_1: 0.06991889327764511, loss_2: 0.371340274810791



  1%|          | 35/5000 [03:47<5:15:03,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 35, loss = 0.8200473785400391, loss_1: 0.06951062381267548, loss_2: 0.37080878019332886



  1%|          | 36/5000 [03:51<5:14:50,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 36, loss = 0.8152350187301636, loss_1: 0.06769603490829468, loss_2: 0.37052100896835327



  1%|          | 37/5000 [03:55<5:14:40,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 37, loss = 0.8116481304168701, loss_1: 0.06681627035140991, loss_2: 0.3690960109233856



  1%|          | 38/5000 [03:59<5:14:25,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 38, loss = 0.8098260164260864, loss_1: 0.06727530807256699, loss_2: 0.36772650480270386



  1%|          | 39/5000 [04:02<5:14:09,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 39, loss = 0.8057270050048828, loss_1: 0.06596241891384125, loss_2: 0.36602944135665894



  1%|          | 40/5000 [04:06<5:14:03,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 40, loss = 0.8013460636138916, loss_1: 0.06471316516399384, loss_2: 0.364091157913208



  1%|          | 41/5000 [04:10<5:14:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 41, loss = 0.7974642515182495, loss_1: 0.06413024663925171, loss_2: 0.36185213923454285



  1%|          | 42/5000 [04:14<5:14:07,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 42, loss = 0.7941405773162842, loss_1: 0.06423243135213852, loss_2: 0.3596753478050232



  1%|          | 43/5000 [04:18<5:14:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 43, loss = 0.7900241613388062, loss_1: 0.06344479322433472, loss_2: 0.35707974433898926



  1%|          | 44/5000 [04:21<5:13:52,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 44, loss = 0.784609854221344, loss_1: 0.06266944110393524, loss_2: 0.35378730297088623



  1%|          | 45/5000 [04:25<5:13:44,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 45, loss = 0.7763715386390686, loss_1: 0.06216862052679062, loss_2: 0.3475392758846283



  1%|          | 46/5000 [04:29<5:13:37,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 46, loss = 0.7623739242553711, loss_1: 0.06143087148666382, loss_2: 0.3361632823944092



  1%|          | 47/5000 [04:33<5:13:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 47, loss = 0.742804229259491, loss_1: 0.06078042834997177, loss_2: 0.31976932287216187



  1%|          | 48/5000 [04:37<5:13:27,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 48, loss = 0.71051424741745, loss_1: 0.060690879821777344, loss_2: 0.29524821043014526



  1%|          | 49/5000 [04:40<5:13:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 49, loss = 0.6787656545639038, loss_1: 0.06328563392162323, loss_2: 0.2789269983768463



  1%|          | 50/5000 [04:44<5:13:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 50, loss = 0.7218616008758545, loss_1: 0.07153453677892685, loss_2: 0.30939409136772156



  1%|          | 51/5000 [04:48<5:13:04,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 51, loss = 0.8638702630996704, loss_1: 0.16234338283538818, loss_2: 0.3034725487232208



  1%|          | 52/5000 [04:52<5:13:08,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 52, loss = 0.6917452812194824, loss_1: 0.06769188493490219, loss_2: 0.2749553322792053



  1%|          | 53/5000 [04:55<5:12:58,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 53, loss = 0.7134805917739868, loss_1: 0.11058847606182098, loss_2: 0.2791929841041565



  1%|          | 54/5000 [04:59<5:12:46,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 54, loss = 0.6517822742462158, loss_1: 0.06882821023464203, loss_2: 0.265737384557724



  1%|          | 55/5000 [05:03<5:12:49,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 55, loss = 0.663500189781189, loss_1: 0.08297568559646606, loss_2: 0.2710742652416229



  1%|          | 56/5000 [05:07<5:12:44,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 56, loss = 0.6386587619781494, loss_1: 0.08197443187236786, loss_2: 0.24694833159446716



  1%|          | 57/5000 [05:11<5:12:46,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 57, loss = 0.6169276237487793, loss_1: 0.06829269975423813, loss_2: 0.254932165145874



  1%|          | 58/5000 [05:14<5:12:40,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 58, loss = 0.6016936302185059, loss_1: 0.0763811320066452, loss_2: 0.23726198077201843



  1%|          | 59/5000 [05:18<5:12:36,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 59, loss = 0.5950159430503845, loss_1: 0.06962031126022339, loss_2: 0.23902110755443573



  1%|          | 60/5000 [05:22<5:12:23,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 60, loss = 0.5836327075958252, loss_1: 0.06417463719844818, loss_2: 0.23451998829841614



  1%|          | 61/5000 [05:26<5:12:23,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 61, loss = 0.5753986239433289, loss_1: 0.0640358179807663, loss_2: 0.2331523597240448



  1%|          | 62/5000 [05:30<5:12:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 62, loss = 0.5671383142471313, loss_1: 0.060849547386169434, loss_2: 0.2249603271484375



  1%|▏         | 63/5000 [05:33<5:12:17,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 63, loss = 0.5573612451553345, loss_1: 0.05773963779211044, loss_2: 0.22085782885551453



  1%|▏         | 64/5000 [05:37<5:12:09,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 64, loss = 0.558944821357727, loss_1: 0.057869236916303635, loss_2: 0.218438982963562



  1%|▏         | 65/5000 [05:41<5:12:11,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 65, loss = 0.5462599992752075, loss_1: 0.0554162859916687, loss_2: 0.2163371741771698



  1%|▏         | 66/5000 [05:45<5:12:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 66, loss = 0.536476194858551, loss_1: 0.053754907101392746, loss_2: 0.21190230548381805



  1%|▏         | 67/5000 [05:49<5:11:59,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 67, loss = 0.538582444190979, loss_1: 0.05326930806040764, loss_2: 0.2139412760734558



  1%|▏         | 68/5000 [05:52<5:12:04,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 68, loss = 0.5312248468399048, loss_1: 0.0543377622961998, loss_2: 0.20955699682235718



  1%|▏         | 69/5000 [05:56<5:11:57,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 69, loss = 0.5224759578704834, loss_1: 0.05301646515727043, loss_2: 0.20533156394958496



  1%|▏         | 70/5000 [06:00<5:12:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 70, loss = 0.5179582834243774, loss_1: 0.052519284188747406, loss_2: 0.2053774893283844



  1%|▏         | 71/5000 [06:04<5:11:54,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 71, loss = 0.5162862539291382, loss_1: 0.05138544365763664, loss_2: 0.2045643925666809



  1%|▏         | 72/5000 [06:08<5:11:39,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 72, loss = 0.5155651569366455, loss_1: 0.05081323906779289, loss_2: 0.20307564735412598



  1%|▏         | 73/5000 [06:11<5:11:41,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 73, loss = 0.5133422613143921, loss_1: 0.0522555336356163, loss_2: 0.20705989003181458



  1%|▏         | 74/5000 [06:15<5:11:48,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 74, loss = 0.5121817588806152, loss_1: 0.051358722150325775, loss_2: 0.20667794346809387



  2%|▏         | 75/5000 [06:19<5:11:42,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 75, loss = 0.5093652009963989, loss_1: 0.04969567060470581, loss_2: 0.20531943440437317



  2%|▏         | 76/5000 [06:23<5:11:29,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 76, loss = 0.49657633900642395, loss_1: 0.04937291145324707, loss_2: 0.192866712808609



  2%|▏         | 77/5000 [06:27<5:11:20,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 77, loss = 0.4938560426235199, loss_1: 0.04882679507136345, loss_2: 0.1976192444562912



  2%|▏         | 78/5000 [06:30<5:11:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 78, loss = 0.5090724229812622, loss_1: 0.05074286460876465, loss_2: 0.2005477249622345



  2%|▏         | 79/5000 [06:34<5:11:17,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 79, loss = 0.490529328584671, loss_1: 0.04886927455663681, loss_2: 0.19328604638576508



  2%|▏         | 80/5000 [06:38<5:11:07,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 80, loss = 0.4817165732383728, loss_1: 0.04855438321828842, loss_2: 0.18862825632095337



  2%|▏         | 81/5000 [06:42<5:12:01,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 81, loss = 0.4821176826953888, loss_1: 0.04812809079885483, loss_2: 0.19001322984695435



  2%|▏         | 82/5000 [06:46<5:11:45,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 82, loss = 0.4823896884918213, loss_1: 0.04769166558980942, loss_2: 0.19483861327171326



  2%|▏         | 83/5000 [06:49<5:11:27,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 83, loss = 0.478567898273468, loss_1: 0.0481145977973938, loss_2: 0.19194993376731873



  2%|▏         | 84/5000 [06:53<5:11:24,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 84, loss = 0.4691218137741089, loss_1: 0.047077976167201996, loss_2: 0.18374252319335938



  2%|▏         | 85/5000 [06:57<5:11:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 85, loss = 0.46707236766815186, loss_1: 0.0473850779235363, loss_2: 0.18423298001289368



  2%|▏         | 86/5000 [07:01<5:11:12,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 86, loss = 0.4640948176383972, loss_1: 0.04651544615626335, loss_2: 0.18177255988121033



  2%|▏         | 87/5000 [07:05<5:10:57,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 87, loss = 0.4674595892429352, loss_1: 0.0461680106818676, loss_2: 0.18003402650356293



  2%|▏         | 88/5000 [07:08<5:10:59,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 88, loss = 0.4751664400100708, loss_1: 0.04635332152247429, loss_2: 0.18424366414546967



  2%|▏         | 89/5000 [07:12<5:10:45,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 89, loss = 0.47454261779785156, loss_1: 0.0468314103782177, loss_2: 0.19486002624034882



  2%|▏         | 90/5000 [07:16<5:10:42,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 90, loss = 0.48354244232177734, loss_1: 0.045985519886016846, loss_2: 0.20446527004241943



  2%|▏         | 91/5000 [07:20<5:10:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 91, loss = 0.46146446466445923, loss_1: 0.045172374695539474, loss_2: 0.18121692538261414



  2%|▏         | 92/5000 [07:24<5:10:35,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 92, loss = 0.4791884124279022, loss_1: 0.04660218954086304, loss_2: 0.20301580429077148



  2%|▏         | 93/5000 [07:27<5:10:26,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 93, loss = 0.49492496252059937, loss_1: 0.04679695889353752, loss_2: 0.18964305520057678



  2%|▏         | 94/5000 [07:31<5:10:27,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 94, loss = 0.45129522681236267, loss_1: 0.046239934861660004, loss_2: 0.17787234485149384



  2%|▏         | 95/5000 [07:35<5:10:18,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 95, loss = 0.5020687580108643, loss_1: 0.045668601989746094, loss_2: 0.19965049624443054



  2%|▏         | 96/5000 [07:39<5:10:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 96, loss = 0.4807620644569397, loss_1: 0.050912898033857346, loss_2: 0.19481059908866882



  2%|▏         | 97/5000 [07:43<5:10:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 97, loss = 0.4730049967765808, loss_1: 0.046799302101135254, loss_2: 0.18927544355392456



  2%|▏         | 98/5000 [07:46<5:10:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 98, loss = 0.4776347875595093, loss_1: 0.04855670779943466, loss_2: 0.19259926676750183



  2%|▏         | 99/5000 [07:50<5:10:14,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 99, loss = 0.46153199672698975, loss_1: 0.047146499156951904, loss_2: 0.18393240869045258



  2%|▏         | 100/5000 [07:54<5:10:11,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 100, loss = 0.4648047983646393, loss_1: 0.044441401958465576, loss_2: 0.1895567774772644



  2%|▏         | 101/5000 [07:58<5:10:02,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 101, loss = 0.44851523637771606, loss_1: 0.045346617698669434, loss_2: 0.1763879656791687



  2%|▏         | 102/5000 [08:02<5:10:06,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 102, loss = 0.45823270082473755, loss_1: 0.04480767250061035, loss_2: 0.18511979281902313



  2%|▏         | 103/5000 [08:05<5:09:52,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 103, loss = 0.4427400827407837, loss_1: 0.04350658506155014, loss_2: 0.17187386751174927



  2%|▏         | 104/5000 [08:09<5:09:51,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 104, loss = 0.4494851529598236, loss_1: 0.044503651559352875, loss_2: 0.179427832365036



  2%|▏         | 105/5000 [08:13<5:09:52,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 105, loss = 0.4401109218597412, loss_1: 0.04270651191473007, loss_2: 0.167972594499588



  2%|▏         | 106/5000 [08:17<5:09:53,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 106, loss = 0.4327816367149353, loss_1: 0.043545544147491455, loss_2: 0.17020228505134583



  2%|▏         | 107/5000 [08:21<5:09:36,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 107, loss = 0.4429818391799927, loss_1: 0.04280390590429306, loss_2: 0.1690903753042221



  2%|▏         | 108/5000 [08:24<5:09:26,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 108, loss = 0.4226943850517273, loss_1: 0.04203653335571289, loss_2: 0.16360971331596375



  2%|▏         | 109/5000 [08:28<5:09:25,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 109, loss = 0.4324915111064911, loss_1: 0.04379473254084587, loss_2: 0.16565704345703125



  2%|▏         | 110/5000 [08:32<5:09:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 110, loss = 0.4269682765007019, loss_1: 0.04155854508280754, loss_2: 0.16507947444915771



  2%|▏         | 111/5000 [08:36<5:09:19,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 111, loss = 0.4209854304790497, loss_1: 0.04291184991598129, loss_2: 0.16459637880325317



  2%|▏         | 112/5000 [08:40<5:09:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 112, loss = 0.4233226478099823, loss_1: 0.041578732430934906, loss_2: 0.1624184250831604



  2%|▏         | 113/5000 [08:43<5:09:17,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 113, loss = 0.42044201493263245, loss_1: 0.04062080383300781, loss_2: 0.16588488221168518



  2%|▏         | 114/5000 [08:47<5:09:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 114, loss = 0.4158991277217865, loss_1: 0.0406009778380394, loss_2: 0.16238640248775482



  2%|▏         | 115/5000 [08:51<5:09:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 115, loss = 0.4121010899543762, loss_1: 0.04024618864059448, loss_2: 0.15931984782218933



  2%|▏         | 116/5000 [08:55<5:09:19,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 116, loss = 0.4104941785335541, loss_1: 0.04027827829122543, loss_2: 0.16120056807994843



  2%|▏         | 117/5000 [08:59<5:09:12,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 117, loss = 0.4079820513725281, loss_1: 0.039731185883283615, loss_2: 0.15953654050827026



  2%|▏         | 118/5000 [09:02<5:09:07,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 118, loss = 0.40398019552230835, loss_1: 0.039819180965423584, loss_2: 0.15752378106117249



  2%|▏         | 119/5000 [09:06<5:09:05,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 119, loss = 0.4057570695877075, loss_1: 0.03941093012690544, loss_2: 0.159394770860672



  2%|▏         | 120/5000 [09:10<5:09:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 120, loss = 0.4033083915710449, loss_1: 0.038908105343580246, loss_2: 0.158421128988266



  2%|▏         | 121/5000 [09:14<5:09:05,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 121, loss = 0.39811915159225464, loss_1: 0.03880147263407707, loss_2: 0.1547721028327942



  2%|▏         | 122/5000 [09:18<5:08:54,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 122, loss = 0.3966861367225647, loss_1: 0.03844982385635376, loss_2: 0.1546829789876938



  2%|▏         | 123/5000 [09:21<5:08:46,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 123, loss = 0.39718401432037354, loss_1: 0.03805087134242058, loss_2: 0.15620705485343933



  2%|▏         | 124/5000 [09:25<5:08:39,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 124, loss = 0.39282098412513733, loss_1: 0.03790052980184555, loss_2: 0.15302161872386932



  2%|▎         | 125/5000 [09:29<5:08:47,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 125, loss = 0.3900712728500366, loss_1: 0.037619173526763916, loss_2: 0.1517869085073471



  3%|▎         | 126/5000 [09:33<5:08:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 126, loss = 0.3920045495033264, loss_1: 0.037639301270246506, loss_2: 0.15215614438056946



  3%|▎         | 127/5000 [09:37<5:08:26,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 127, loss = 0.3928872346878052, loss_1: 0.03772544860839844, loss_2: 0.15134237706661224



  3%|▎         | 128/5000 [09:40<5:08:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 128, loss = 0.3962934613227844, loss_1: 0.037115078419446945, loss_2: 0.15050753951072693



  3%|▎         | 129/5000 [09:44<5:09:13,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 129, loss = 0.390277624130249, loss_1: 0.037046194076538086, loss_2: 0.14922837913036346



  3%|▎         | 130/5000 [09:48<5:09:04,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 130, loss = 0.3878363370895386, loss_1: 0.03677962347865105, loss_2: 0.15053348243236542



  3%|▎         | 131/5000 [09:52<5:08:50,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 131, loss = 0.3931330740451813, loss_1: 0.03692915290594101, loss_2: 0.1584559977054596



  3%|▎         | 132/5000 [09:56<5:08:41,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 132, loss = 0.39845001697540283, loss_1: 0.03652787208557129, loss_2: 0.16352884471416473



  3%|▎         | 133/5000 [09:59<5:08:34,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 133, loss = 0.38645997643470764, loss_1: 0.036098480224609375, loss_2: 0.15122562646865845



  3%|▎         | 134/5000 [10:03<5:08:19,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 134, loss = 0.3857397437095642, loss_1: 0.036418020725250244, loss_2: 0.14995774626731873



  3%|▎         | 135/5000 [10:07<5:08:09,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 135, loss = 0.3888639509677887, loss_1: 0.035715263336896896, loss_2: 0.15343128144741058



  3%|▎         | 136/5000 [10:11<5:07:54,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 136, loss = 0.38931721448898315, loss_1: 0.03599029779434204, loss_2: 0.1564536988735199



  3%|▎         | 137/5000 [10:15<5:07:55,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 137, loss = 0.3926733732223511, loss_1: 0.036109767854213715, loss_2: 0.15789370238780975



  3%|▎         | 138/5000 [10:18<5:07:58,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 138, loss = 0.3883798122406006, loss_1: 0.03586409613490105, loss_2: 0.15104368329048157



  3%|▎         | 139/5000 [10:22<5:07:49,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 139, loss = 0.3754497468471527, loss_1: 0.03528940677642822, loss_2: 0.14513933658599854



  3%|▎         | 140/5000 [10:26<5:07:32,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 140, loss = 0.3829666078090668, loss_1: 0.036040108650922775, loss_2: 0.14799386262893677



  3%|▎         | 141/5000 [10:30<5:07:32,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 141, loss = 0.38366633653640747, loss_1: 0.03495675325393677, loss_2: 0.149287149310112



  3%|▎         | 142/5000 [10:34<5:07:17,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 142, loss = 0.3762299418449402, loss_1: 0.034630656242370605, loss_2: 0.1464192122220993



  3%|▎         | 143/5000 [10:37<5:07:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 143, loss = 0.388741135597229, loss_1: 0.03525429964065552, loss_2: 0.15108343958854675



  3%|▎         | 144/5000 [10:41<5:07:11,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 144, loss = 0.3850961923599243, loss_1: 0.035926125943660736, loss_2: 0.15127120912075043



  3%|▎         | 145/5000 [10:45<5:07:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 145, loss = 0.3777494430541992, loss_1: 0.034267961978912354, loss_2: 0.14859601855278015



  3%|▎         | 146/5000 [10:49<5:07:03,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 146, loss = 0.3879922032356262, loss_1: 0.03660418838262558, loss_2: 0.15317308902740479



  3%|▎         | 147/5000 [10:53<5:07:03,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 147, loss = 0.3766854703426361, loss_1: 0.03383207321166992, loss_2: 0.14776507019996643



  3%|▎         | 148/5000 [10:56<5:07:03,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 148, loss = 0.379481703042984, loss_1: 0.03401949256658554, loss_2: 0.14950017631053925



  3%|▎         | 149/5000 [11:00<5:07:08,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 149, loss = 0.3745068907737732, loss_1: 0.03348877280950546, loss_2: 0.14366021752357483



  3%|▎         | 150/5000 [11:04<5:07:08,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 150, loss = 0.37234610319137573, loss_1: 0.03392340987920761, loss_2: 0.14373718202114105



  3%|▎         | 151/5000 [11:08<5:07:07,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 151, loss = 0.3885120749473572, loss_1: 0.03354921191930771, loss_2: 0.1553291529417038



  3%|▎         | 152/5000 [11:11<5:06:52,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 152, loss = 0.4005453288555145, loss_1: 0.0333019495010376, loss_2: 0.16502225399017334



  3%|▎         | 153/5000 [11:15<5:06:35,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 153, loss = 0.41373294591903687, loss_1: 0.03304920718073845, loss_2: 0.16303947567939758



  3%|▎         | 154/5000 [11:19<5:06:35,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 154, loss = 0.37491387128829956, loss_1: 0.03372516483068466, loss_2: 0.14722010493278503



  3%|▎         | 155/5000 [11:23<5:06:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 155, loss = 0.3940688669681549, loss_1: 0.03306722640991211, loss_2: 0.15541866421699524



  3%|▎         | 156/5000 [11:27<5:06:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 156, loss = 0.41402342915534973, loss_1: 0.03446177765727043, loss_2: 0.16180437803268433



  3%|▎         | 157/5000 [11:30<5:06:24,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 157, loss = 0.3809342384338379, loss_1: 0.03348449990153313, loss_2: 0.15079627931118011



  3%|▎         | 158/5000 [11:34<5:06:43,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 158, loss = 0.41772228479385376, loss_1: 0.03408859297633171, loss_2: 0.16636048257350922



  3%|▎         | 159/5000 [11:38<5:06:35,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 159, loss = 0.4058680534362793, loss_1: 0.034513115882873535, loss_2: 0.16423720121383667



  3%|▎         | 160/5000 [11:42<5:06:33,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 160, loss = 0.3913193345069885, loss_1: 0.033130668103694916, loss_2: 0.15782761573791504



  3%|▎         | 161/5000 [11:46<5:06:29,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 161, loss = 0.4192989766597748, loss_1: 0.03532761335372925, loss_2: 0.17403638362884521



  3%|▎         | 162/5000 [11:49<5:06:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 162, loss = 0.3797888159751892, loss_1: 0.03322277590632439, loss_2: 0.1483212113380432



  3%|▎         | 163/5000 [11:53<5:06:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 163, loss = 0.4009828269481659, loss_1: 0.035239897668361664, loss_2: 0.1659242957830429



  3%|▎         | 164/5000 [11:57<5:06:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 164, loss = 0.3788692355155945, loss_1: 0.03424715995788574, loss_2: 0.14893850684165955



  3%|▎         | 165/5000 [12:01<5:06:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 165, loss = 0.3841404914855957, loss_1: 0.03521698713302612, loss_2: 0.15377794206142426



  3%|▎         | 166/5000 [12:05<5:06:10,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 166, loss = 0.3787184953689575, loss_1: 0.0344938263297081, loss_2: 0.14822392165660858



  3%|▎         | 167/5000 [12:08<5:06:06,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 167, loss = 0.38044729828834534, loss_1: 0.034175120294094086, loss_2: 0.15005557239055634



  3%|▎         | 168/5000 [12:12<5:06:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 168, loss = 0.3731922507286072, loss_1: 0.03304457664489746, loss_2: 0.14573709666728973



  3%|▎         | 169/5000 [12:16<5:05:58,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 169, loss = 0.3742184042930603, loss_1: 0.033244114369153976, loss_2: 0.14319771528244019



  3%|▎         | 170/5000 [12:20<5:05:52,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 170, loss = 0.37104636430740356, loss_1: 0.03272557258605957, loss_2: 0.14342957735061646



  3%|▎         | 171/5000 [12:24<5:05:49,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 171, loss = 0.3692026138305664, loss_1: 0.03323904797434807, loss_2: 0.14152668416500092



  3%|▎         | 172/5000 [12:27<5:05:49,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 172, loss = 0.36975452303886414, loss_1: 0.0323498472571373, loss_2: 0.14373773336410522



  3%|▎         | 173/5000 [12:31<5:05:44,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 173, loss = 0.3652445077896118, loss_1: 0.032650213688611984, loss_2: 0.1409567892551422



  3%|▎         | 174/5000 [12:35<5:05:28,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 174, loss = 0.36586710810661316, loss_1: 0.03245890140533447, loss_2: 0.14095835387706757



  4%|▎         | 175/5000 [12:39<5:05:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 175, loss = 0.36257514357566833, loss_1: 0.03220641613006592, loss_2: 0.13903099298477173



  4%|▎         | 176/5000 [12:43<5:05:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 176, loss = 0.3627305030822754, loss_1: 0.03182981535792351, loss_2: 0.13872212171554565



  4%|▎         | 177/5000 [12:47<5:06:17,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 177, loss = 0.36098557710647583, loss_1: 0.031584084033966064, loss_2: 0.1386648714542389



  4%|▎         | 178/5000 [12:50<5:05:54,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 178, loss = 0.359244704246521, loss_1: 0.03156912326812744, loss_2: 0.13733212649822235



  4%|▎         | 179/5000 [12:54<5:05:49,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 179, loss = 0.3594256639480591, loss_1: 0.031392037868499756, loss_2: 0.13743841648101807



  4%|▎         | 180/5000 [12:58<5:05:33,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 180, loss = 0.3571487069129944, loss_1: 0.03139837831258774, loss_2: 0.13617101311683655



  4%|▎         | 181/5000 [13:02<5:05:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 181, loss = 0.3567109704017639, loss_1: 0.031086444854736328, loss_2: 0.13568627834320068



  4%|▎         | 182/5000 [13:06<5:05:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 182, loss = 0.3559819757938385, loss_1: 0.03116464614868164, loss_2: 0.13538223505020142



  4%|▎         | 183/5000 [13:09<5:04:59,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 183, loss = 0.354663610458374, loss_1: 0.031056564301252365, loss_2: 0.13440364599227905



  4%|▎         | 184/5000 [13:13<5:04:51,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 184, loss = 0.3545491695404053, loss_1: 0.031110724434256554, loss_2: 0.13463425636291504



  4%|▎         | 185/5000 [13:17<5:04:52,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 185, loss = 0.3549676537513733, loss_1: 0.03099358081817627, loss_2: 0.13486897945404053



  4%|▎         | 186/5000 [13:21<5:04:40,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 186, loss = 0.35602134466171265, loss_1: 0.031025132164359093, loss_2: 0.13641443848609924



  4%|▎         | 187/5000 [13:24<5:04:41,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 187, loss = 0.35454514622688293, loss_1: 0.030855834484100342, loss_2: 0.1354263722896576



  4%|▍         | 188/5000 [13:28<5:04:34,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 188, loss = 0.3514920473098755, loss_1: 0.03068389557301998, loss_2: 0.13253766298294067



  4%|▍         | 189/5000 [13:32<5:04:40,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 189, loss = 0.3514067530632019, loss_1: 0.03058238886296749, loss_2: 0.13280624151229858



  4%|▍         | 190/5000 [13:36<5:04:33,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 190, loss = 0.3547762334346771, loss_1: 0.03045940399169922, loss_2: 0.13644491136074066



  4%|▍         | 191/5000 [13:40<5:04:32,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 191, loss = 0.35685473680496216, loss_1: 0.030337613075971603, loss_2: 0.13830696046352386



  4%|▍         | 192/5000 [13:43<5:04:26,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 192, loss = 0.35356011986732483, loss_1: 0.030310314148664474, loss_2: 0.13542109727859497



  4%|▍         | 193/5000 [13:47<5:04:24,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 193, loss = 0.3502877950668335, loss_1: 0.030295968055725098, loss_2: 0.13254354894161224



  4%|▍         | 194/5000 [13:51<5:04:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 194, loss = 0.35285550355911255, loss_1: 0.030287664383649826, loss_2: 0.13529953360557556



  4%|▍         | 195/5000 [13:55<5:04:19,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 195, loss = 0.35180267691612244, loss_1: 0.030252795666456223, loss_2: 0.13377808034420013



  4%|▍         | 196/5000 [13:59<5:04:07,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 196, loss = 0.3487972021102905, loss_1: 0.030331116169691086, loss_2: 0.13070765137672424



  4%|▍         | 197/5000 [14:02<5:03:57,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 197, loss = 0.3479207754135132, loss_1: 0.030164003372192383, loss_2: 0.1302955448627472



  4%|▍         | 198/5000 [14:06<5:03:54,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 198, loss = 0.34916040301322937, loss_1: 0.03011687658727169, loss_2: 0.1315336525440216



  4%|▍         | 199/5000 [14:10<5:04:03,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 199, loss = 0.3515860438346863, loss_1: 0.030178368091583252, loss_2: 0.13382995128631592



  4%|▍         | 200/5000 [14:14<5:03:56,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 200, loss = 0.3553135395050049, loss_1: 0.029961805790662766, loss_2: 0.13809852302074432



  4%|▍         | 201/5000 [14:18<5:03:43,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 201, loss = 0.3490995168685913, loss_1: 0.029944539070129395, loss_2: 0.13179008662700653



  4%|▍         | 202/5000 [14:21<5:03:32,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 202, loss = 0.34736475348472595, loss_1: 0.029925605282187462, loss_2: 0.13010546565055847



  4%|▍         | 203/5000 [14:25<5:03:34,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 203, loss = 0.3501814901828766, loss_1: 0.030389606952667236, loss_2: 0.1320069283246994



  4%|▍         | 204/5000 [14:29<5:03:24,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 204, loss = 0.3535795211791992, loss_1: 0.029874881729483604, loss_2: 0.13576459884643555



  4%|▍         | 205/5000 [14:33<5:03:27,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 205, loss = 0.35468754172325134, loss_1: 0.02969680353999138, loss_2: 0.1382528394460678



  4%|▍         | 206/5000 [14:37<5:03:17,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 206, loss = 0.3495554029941559, loss_1: 0.029852688312530518, loss_2: 0.1326589286327362



  4%|▍         | 207/5000 [14:40<5:03:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 207, loss = 0.3525722622871399, loss_1: 0.03020864725112915, loss_2: 0.13228151202201843



  4%|▍         | 208/5000 [14:44<5:03:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 208, loss = 0.3569537401199341, loss_1: 0.0299147367477417, loss_2: 0.13428789377212524



  4%|▍         | 209/5000 [14:48<5:03:11,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 209, loss = 0.3482099771499634, loss_1: 0.02992415428161621, loss_2: 0.1310228407382965



  4%|▍         | 210/5000 [14:52<5:03:00,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 210, loss = 0.344419926404953, loss_1: 0.029619812965393066, loss_2: 0.12802478671073914



  4%|▍         | 211/5000 [14:56<5:02:59,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 211, loss = 0.350578248500824, loss_1: 0.02966727875173092, loss_2: 0.13191962242126465



  4%|▍         | 212/5000 [14:59<5:02:56,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 212, loss = 0.3472200632095337, loss_1: 0.029573839157819748, loss_2: 0.1294238269329071



  4%|▍         | 213/5000 [15:03<5:03:02,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 213, loss = 0.3429093360900879, loss_1: 0.02962448261678219, loss_2: 0.12717697024345398



  4%|▍         | 214/5000 [15:07<5:03:06,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 214, loss = 0.3444594144821167, loss_1: 0.02969859167933464, loss_2: 0.12872014939785004



  4%|▍         | 215/5000 [15:11<5:03:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 215, loss = 0.34947025775909424, loss_1: 0.02936077117919922, loss_2: 0.13324397802352905



  4%|▍         | 216/5000 [15:15<5:02:55,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 216, loss = 0.3481242060661316, loss_1: 0.02947831153869629, loss_2: 0.13275837898254395



  4%|▍         | 217/5000 [15:18<5:02:53,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 217, loss = 0.3451768755912781, loss_1: 0.02946551702916622, loss_2: 0.12914568185806274



  4%|▍         | 218/5000 [15:22<5:02:44,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 218, loss = 0.34930670261383057, loss_1: 0.02926812693476677, loss_2: 0.13091784715652466



  4%|▍         | 219/5000 [15:26<5:02:40,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 219, loss = 0.3505687117576599, loss_1: 0.029633402824401855, loss_2: 0.13254563510417938



  4%|▍         | 220/5000 [15:30<5:02:35,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 220, loss = 0.3418213725090027, loss_1: 0.029440145939588547, loss_2: 0.1268155425786972



  4%|▍         | 221/5000 [15:34<5:02:37,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 221, loss = 0.34531575441360474, loss_1: 0.029584785923361778, loss_2: 0.1279665231704712



  4%|▍         | 222/5000 [15:37<5:02:24,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 222, loss = 0.3516882061958313, loss_1: 0.029363295063376427, loss_2: 0.12835919857025146



  4%|▍         | 223/5000 [15:41<5:02:27,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 223, loss = 0.34421291947364807, loss_1: 0.029265185818076134, loss_2: 0.12506155669689178



  4%|▍         | 224/5000 [15:45<5:02:23,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 224, loss = 0.344413161277771, loss_1: 0.02956179901957512, loss_2: 0.12802015244960785



  4%|▍         | 225/5000 [15:49<5:03:17,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 225, loss = 0.34720101952552795, loss_1: 0.029377004131674767, loss_2: 0.1304401010274887



  5%|▍         | 226/5000 [15:53<5:03:02,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 226, loss = 0.3412030339241028, loss_1: 0.02916310355067253, loss_2: 0.12606191635131836



  5%|▍         | 227/5000 [15:56<5:02:43,  3.81s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 227, loss = 0.34156256914138794, loss_1: 0.029356103390455246, loss_2: 0.12439094483852386



  5%|▍         | 228/5000 [16:00<5:02:22,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 228, loss = 0.34687626361846924, loss_1: 0.029267173260450363, loss_2: 0.1270238161087036



  5%|▍         | 229/5000 [16:04<5:02:15,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 229, loss = 0.34244275093078613, loss_1: 0.02924249693751335, loss_2: 0.126024067401886



  5%|▍         | 230/5000 [16:08<5:01:59,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 230, loss = 0.34108108282089233, loss_1: 0.029020309448242188, loss_2: 0.1265188753604889



  5%|▍         | 231/5000 [16:12<5:01:57,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 231, loss = 0.3428035378456116, loss_1: 0.02898508310317993, loss_2: 0.12886832654476166



  5%|▍         | 232/5000 [16:15<5:01:42,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 232, loss = 0.34208568930625916, loss_1: 0.029052218422293663, loss_2: 0.1261637657880783



  5%|▍         | 233/5000 [16:19<5:01:42,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 233, loss = 0.33981403708457947, loss_1: 0.028783779591321945, loss_2: 0.12266236543655396



  5%|▍         | 234/5000 [16:23<5:01:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 234, loss = 0.3393569886684418, loss_1: 0.02873307466506958, loss_2: 0.12288471311330795



  5%|▍         | 235/5000 [16:27<5:01:39,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 235, loss = 0.3361318111419678, loss_1: 0.02865169569849968, loss_2: 0.12264394760131836



  5%|▍         | 236/5000 [16:31<5:01:28,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 236, loss = 0.33567219972610474, loss_1: 0.028675219044089317, loss_2: 0.12198134511709213



  5%|▍         | 237/5000 [16:34<5:01:30,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 237, loss = 0.33877480030059814, loss_1: 0.0285150408744812, loss_2: 0.12325843423604965



  5%|▍         | 238/5000 [16:38<5:01:22,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 238, loss = 0.3425772190093994, loss_1: 0.028505247086286545, loss_2: 0.12394329160451889



  5%|▍         | 239/5000 [16:42<5:01:22,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 239, loss = 0.35695070028305054, loss_1: 0.029015403240919113, loss_2: 0.13300099968910217



  5%|▍         | 240/5000 [16:46<5:01:20,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 240, loss = 0.3531002700328827, loss_1: 0.028620600700378418, loss_2: 0.13494496047496796



  5%|▍         | 241/5000 [16:50<5:01:24,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 241, loss = 0.34068575501441956, loss_1: 0.028898417949676514, loss_2: 0.12629100680351257



  5%|▍         | 242/5000 [16:53<5:01:18,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 242, loss = 0.3426899313926697, loss_1: 0.028656721115112305, loss_2: 0.12685728073120117



  5%|▍         | 243/5000 [16:57<5:01:16,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 243, loss = 0.34567591547966003, loss_1: 0.028727035969495773, loss_2: 0.12971389293670654



  5%|▍         | 244/5000 [17:01<5:01:14,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 244, loss = 0.3418380618095398, loss_1: 0.02905251644551754, loss_2: 0.12436209619045258



  5%|▍         | 245/5000 [17:05<5:01:05,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 245, loss = 0.33768248558044434, loss_1: 0.02846185490489006, loss_2: 0.1220640167593956



  5%|▍         | 246/5000 [17:09<5:00:53,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 246, loss = 0.33604955673217773, loss_1: 0.02897399663925171, loss_2: 0.12319555133581161



  5%|▍         | 247/5000 [17:12<5:00:57,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 247, loss = 0.33603423833847046, loss_1: 0.028426846489310265, loss_2: 0.12167727947235107



  5%|▍         | 248/5000 [17:16<5:00:44,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 248, loss = 0.3367074728012085, loss_1: 0.028430622071027756, loss_2: 0.12139630317687988



  5%|▍         | 249/5000 [17:20<5:00:47,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 249, loss = 0.3359888792037964, loss_1: 0.02861110493540764, loss_2: 0.12152819335460663



  5%|▌         | 250/5000 [17:24<5:00:34,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 250, loss = 0.33501946926116943, loss_1: 0.028463046997785568, loss_2: 0.12151114642620087



  5%|▌         | 251/5000 [17:28<5:00:48,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 251, loss = 0.3369191884994507, loss_1: 0.02852056547999382, loss_2: 0.12186567485332489



  5%|▌         | 252/5000 [17:31<5:00:46,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 252, loss = 0.3399622142314911, loss_1: 0.028305253013968468, loss_2: 0.12370713800191879



  5%|▌         | 253/5000 [17:35<5:00:31,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 253, loss = 0.34371042251586914, loss_1: 0.02854720875620842, loss_2: 0.12608757615089417



  5%|▌         | 254/5000 [17:39<5:00:21,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 254, loss = 0.35349273681640625, loss_1: 0.029842814430594444, loss_2: 0.13168823719024658



  5%|▌         | 255/5000 [17:43<5:00:10,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 255, loss = 0.3402103781700134, loss_1: 0.028798937797546387, loss_2: 0.12403266131877899



  5%|▌         | 256/5000 [17:47<5:00:04,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 256, loss = 0.3446866273880005, loss_1: 0.028397899121046066, loss_2: 0.1274224817752838



  5%|▌         | 257/5000 [17:50<4:59:57,  3.79s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 257, loss = 0.3544284701347351, loss_1: 0.028104543685913086, loss_2: 0.13309064507484436



  5%|▌         | 258/5000 [17:54<5:00:01,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 258, loss = 0.3368266224861145, loss_1: 0.028342386707663536, loss_2: 0.12391053140163422



  5%|▌         | 259/5000 [17:58<5:00:09,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 259, loss = 0.3618953824043274, loss_1: 0.028423985466361046, loss_2: 0.13842067122459412



  5%|▌         | 260/5000 [18:02<5:00:02,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 260, loss = 0.3656459450721741, loss_1: 0.02835458517074585, loss_2: 0.1367369145154953



  5%|▌         | 261/5000 [18:06<4:59:50,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 261, loss = 0.34160754084587097, loss_1: 0.02877432107925415, loss_2: 0.12531647086143494



  5%|▌         | 262/5000 [18:09<4:59:58,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 262, loss = 0.37587404251098633, loss_1: 0.028621237725019455, loss_2: 0.14186780154705048



  5%|▌         | 263/5000 [18:13<4:59:47,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])
in epoch 263, loss = 0.34868907928466797, loss_1: 0.029120683670043945, loss_2: 0.12802819907665253



  5%|▌         | 264/5000 [18:17<4:59:46,  3.80s/it]

conv5_4_interp: torch.Size([1, 64, 64, 64, 64])
conv3_1_sub2_proj: torch.Size([1, 32, 16, 16, 16])
conv_sub4: torch.Size([1, 32, 16, 16, 16])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
x: torch.Size([1, 1, 256, 256, 256])
conv1_sub1: torch.Size([1, 8, 256, 256, 256])
conv2_sub1: torch.Size([1, 8, 256, 256, 256])
conv3_sub1: torch.Size([1, 16, 256, 256, 256])
conv_sub2: torch.Size([1, 32, 64, 64, 64])
conv3_sub1_proj: torch.Size([1, 32, 128, 128, 128])


KeyboardInterrupt: 

In [None]:
# overfit model on single embryo image (modified ICNet Model)
# upsample final outputs by a factor of 4 instead of factor 2
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

record = open('over_fit_modified_model.txt','w+')

for e in tqdm(range(epochs)):
    
    out_1, out_2, out_4 = icnet1(image_1)
        
    loss_4 = dice_loss_3(out_4, label_4_resize_2)
    loss_2 = dice_loss_3(out_2, label_2_resize_2)
    loss_1 = dice_loss_3(out_1, label_1_resize_2)
    
    #loss = loss_4 + loss_2 + loss_1 
    loss = loss_1 + loss_2 + loss_4
    
    outstr = 'in epoch {}, loss = {}, loss_1: {}'.format(e, loss.item(), loss_1.item()) + '\n'
    
    print(outstr) 
    record.write(outstr)
    record.flush()
    
    optimizer1.zero_grad()
    loss.backward()
    optimizer1.step()

record.close()

In [None]:
# overfit model on single embryo image (original ICNet Model)

from loss import *
from tqdm import tqdm

epochs = 5000

record = open('over_fit_original_model.txt','w+')

for e in tqdm(range(epochs)):
    
    out_1, out_2, out_4 = icnet2(image_1)
        
    loss_4 = dice_loss_3(out_4, label_4_resize_4)
    loss_2 = dice_loss_3(out_2, label_2_resize_4)
    loss_1 = dice_loss_3(out_1, label_1_resize_4)
    
    #loss = loss_4 + loss_2 + loss_1 
    loss = loss_4 + loss_2 + loss_1
    
    outstr = 'in epoch {}, loss = {}, loss_1: {}, loss_2: {}, loss_4: {}'.format(e, loss.item(), loss_1.item(), loss_2.item(), loss_4.item()) + '\n'
    
    print(outstr) 
    record.write(outstr)
    record.flush()
    
    optimizer2.zero_grad()
    loss.backward()
    optimizer2.step()

record.close()

In [None]:
# overfit model on single embryo image (modified ICNet Model)
# upsample final outputs by a factor of 4 instead of factor 2
import datetime
from loss import *
from tqdm import tqdm

epochs = 5000

record = open('over_fit_deeplab.txt','w+')

for e in tqdm(range(epochs)):
    
    out_1 = deeplab(image_1)
        
    
    loss_1 = dice_loss_3(out_1, label_1)
    
    #loss = loss_4 + loss_2 + loss_1 
    loss = loss_1
    
    outstr = 'in epoch {}, loss = {}'.format(e, loss.item()) + '\n'
    
    print(outstr) 
    record.write(outstr)
    record.flush()
    
    optimizer1.zero_grad()
    loss.backward()
    optimizer1.step()

record.close()