In [1]:
%matplotlib notebook
import os, sys
import logging
import random
import h5py
import shutil
import time
import argparse
import numpy as np
import sigpy.plot as pl
import torch
import sigpy as sp
import torchvision
from torch import optim
from tensorboardX import SummaryWriter
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib
# import custom libraries
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx
from utils.resnet2p1d import generate_model
from utils.flare_utils import roll
# import custom classes
from utils.datasets import SliceData
from subsample_fastmri import MaskFunc
from MoDL_single import UnrolledModel
import argparse
from models.SAmodel import MyNetwork
from models.Unrolled import Unrolled
from models.UnrolledRef import UnrolledRef
from models.UnrolledTransformer import UnrolledTrans
import matplotlib.pyplot as plt
from ImageFusionBlock import ImageFusionBlock
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 0
from ImageFusion_Dualbranch_Fusion.densefuse_net import DenseFuseNet
from ImageFusion_Dualbranch_Fusion.channel_fusion import channel_f as channel_fusion
import itertools
from RCAN import CombinedNetwork
from models.FusionNet import FusionNet
from recon_net_wrap_reg import ViT
from UnrolledViTreg import UnrolledViTreg
#from UnrolledViTcomplex import UnrolledViT

from fastmri.data import transforms, subsample

In [2]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [3]:
class DataTransform:
    """
    Data Transformer for training unrolled reconstruction models.
    """

    def __init__(self, mask_func, args, use_seed=False):
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.rng = np.random.RandomState()
    def get_mask_func(self, factor):
        center_fractions = 0.08 * 4/factor# RandomMaskFuncEquiSpacedMaskFunc
        mask_func = subsample.EquiSpacedMaskFunc(
        center_fractions=[center_fractions],
        accelerations=[factor], 
        )
        return mask_func
    
    def __call__(self, kspace, target, reference_kspace, reference,slice):
        im_lowres = abs(sp.ifft(sp.resize(sp.resize(kspace,(256,24)),(256,160))))
        magnitude_vals = im_lowres.reshape(-1)
        k = int(round(0.05 * magnitude_vals.shape[0]))
        scale = magnitude_vals[magnitude_vals.argsort()[::-1][k]]
        kspace = kspace/scale
        target = target/scale
        # Convert everything from numpy arrays to tensors
        kspace_torch = cplx.to_tensor(kspace).float()   
        target_torch = cplx.to_tensor(target).float()  
        target_torch = T.ifft2(T.kspace_cut(T.fft2(target_torch),0.67,0.67)) 
        # Use poisson mask instead
        #mask2 = sp.mri.poisson((256,160), 5, calib=(18, 14), dtype=float, crop_corner=False, return_density=True, seed=0, max_attempts=6, tol=0.01)
        #mask2[128-10:128+9,80-8:80+7] = 1
        #mask_torch = torch.stack([torch.tensor(mask2).float(),torch.tensor(mask2).float()],dim=2)
        #mask_torch = T.kspace_crop(mask_torch,0.67)
        #kspace_torch = T.kspace_cut(mask_torch,0.5)
        kspace_torch = T.awgn_torch(kspace_torch,20,L=1) # 10dB for simulations
        ## Masking
        mask_func = self.get_mask_func(3)
        kspace_torch = T.kspace_cut(kspace_torch,0.67,0.67)
        kspace_torch = transforms.apply_mask(kspace_torch, mask_func)[0]
        # kspace_torch = kspace_torch*mask_torch # For poisson
        
        mask = np.abs(cplx.to_numpy(kspace_torch))!=0
        mask_torch = torch.stack([torch.tensor(mask).float(),torch.tensor(mask).float()],dim=2)
        
        ### Reference addition ###
        im_lowres_ref = abs(sp.ifft(sp.resize(sp.resize(reference_kspace,(256,24)),(256,160))))
        magnitude_vals_ref = im_lowres_ref.reshape(-1)
        k_ref = int(round(0.05 * magnitude_vals_ref.shape[0]))
        scale_ref = magnitude_vals_ref[magnitude_vals_ref.argsort()[::-1][k_ref]]
        reference = reference / scale_ref
        reference_torch = cplx.to_tensor(reference).float()
        reference_torch_kspace = T.fft2(reference_torch)
        reference_torch_kspace = T.kspace_cut(reference_torch_kspace,0.67,0.67)
        reference_torch = T.ifft2(reference_torch_kspace)
        

        return kspace_torch,target_torch,mask_torch, reference_torch 

In [4]:
def create_datasets(args):
    # Generate k-t undersampling masks
    train_mask = MaskFunc([0.08],[4])
    train_data = SliceData(
        root=str(args.data_path),
        transform=DataTransform(train_mask, args),
        sample_rate=1
    )
    return train_data
def create_data_loaders(args):
    train_data = create_datasets(args)
#     print(train_data[0])

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader
def build_optim(args, params):
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    return optimizer

In [5]:
#Hyper parameters
params = Namespace()
#params.data_path = "./registered_data/patient23b/"
params.data_path = "./registered_data/"
params.batch_size = 24 #4
params.num_grad_steps = 1 #4
params.num_cg_steps = 8 #8
params.share_weights = True
params.modl_lamda = 0.05
params.lr = 0.0001 #0.0005 # used to be 0.0001
#params.lr = 0.0001
params.weight_decay = 0
params.lr_step_size = 6
params.lr_gamma = 0.3
params.epoch = 101
params.reference_mode = 1
params.reference_lambda = 0.1


In [6]:
train_loader = create_data_loaders(params)

In [7]:
from torchvision import models
#model_ft = models.resnet18(weights='DEFAULT').to(device).requires_grad_(False)
#model_ft.fc = nn.Identity()
#model_ft = models.vgg16(weights='DEFAULT').to(device)#.requires_grad_(False)
from FSloss_wrap import VGGLoss,ResNet18Backbone,FeatureEmbedding,contrastive_loss,VGGPerceptualLoss
#VGGloss = VGGLoss().to(device)
VGGloss = VGGPerceptualLoss().to(device)
#UFLoss = ResNet18Backbone().to(device)
#UFLoss = VGGLoss().to(device)
#UFLoss = models.vgg16(pretrained=True).features[:8+1].to(device)
#UFLoss.eval()

def extract_patches(images, patch_size=(20, 20), stride=(20, 20)):
    # images: Tensor of shape (batch_size, 1, 180, 110)
    patches = images.unfold(2, patch_size[0], stride[0]).unfold(3, patch_size[1], stride[1])
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
    patches = patches.view(images.size(0), -1, 1, patch_size[0], patch_size[1])
    return patches  # Returns patches of shape (batch_size, num_patches, 1, patch_size[0], patch_size[1])


# Example usage:
patch_size = (20, 20)
stride = (20, 20)  # Non-overlapping patches
def feature_space_loss(features1, features2):
    return F.mse_loss(features1, features2)
def pad_image(images):
    # images: Tensor of shape (batch_size, 1, 172, 108)
    padded_images = F.pad(images, (6, 6, 4, 4), mode='constant', value=0)
    return padded_images  # Shape will be (batch_size, 1, 180, 120)
#modelLoss = ResNet18Backbone().to(device)
#embedding_model = FeatureEmbedding(modelLoss).to(device)
#memory_bank = torch.randn(16, 128)  # Assuming num_patches is the number of different patches stored.
#memory_bank = nn.functional.normalize(memory_bank, p=2, dim=1)  # Normalize the memory bank vectors


from vision_transformer import VisionTransformer
net = VisionTransformer(
  avrg_img_size=320,
  patch_size = (10,10),
  in_chans=1,
  embed_dim=64,
  depth=10,
  num_heads=16

)

from recon_net import ReconNet
model = UnrolledViTreg(params).to(device)
#model2 = ReconNet(net).to(device)#.requires_grad_(False)
#cp = torch.load('./lsdir-2x+hq50k_vit_epoch_60.pt', map_location=device)
#model2.load_state_dict(cp['model_state_dict'])

"""
model.requires_grad_(False)

for net in model.similaritynets:
    net.param1.requires_grad_(True)
    net.param2.requires_grad_(True)
    #net.recon_net.net.head.requires_grad_(True)
"""
"""
optimizer = optim.Adam(model.parameters(), lr=0.0)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer, 
    max_lr=0.0001,
    steps_per_epoch=len(train_loader),
    epochs=params.epoch,
    pct_start=0.01,
    anneal_strategy='linear',
    cycle_momentum=False,
    base_momentum=0., 
    max_momentum=0.,
    div_factor = 25.,
    final_div_factor=1.,
)
"""
# fine tune training

optimizer = build_optim(params,  model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, params.lr_step_size, params.lr_gamma)

## For ViT only training

optimizer = optim.Adam(model.parameters(), lr=0.0)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer, 
    max_lr=0.0002,
    steps_per_epoch=len(train_loader),
    epochs=100,
    pct_start=0.01,
    anneal_strategy='linear',
    cycle_momentum=False,
    base_momentum=0., 
    max_momentum=0.,
    div_factor = 25.,
    final_div_factor=1.,
)




shared weights


  cp = torch.load('./checkpoints_trained_start/model_100.pt', map_location=self.device) # Try new pretrained 11.12


In [8]:
### Load for fine-tunning
"""
checkpoint_file  = "./checkpoints_ViT_fuser_aug/model_40.pt" 
checkpoint = torch.load(checkpoint_file,map_location=device)
model.load_state_dict(checkpoint['model'])
"""

from fastmri.losses import SSIMLoss
criterion = SSIMLoss().to(device)
criterionMSE = nn.MSELoss()
#criterion = nn.L1Loss()

epochs_plot = []
losses_plot = []

for epoch in range(params.epoch):
    model.train()
    avg_loss = 0.
    running_loss = 0.0
    for iter, data in enumerate(train_loader):
        input,target,mask,reference = data
        input = input.to(device).float()
        target = target.to(device).float()
        mask = mask.to(device)
        reference = reference.to(device).float()
        image = T.ifft2(input)
        image = image.permute(0,3,1,2)

        #print(f'image shape: {image.shape}')
        #print(f'reference shape: {reference.shape}')

        target_image = target.permute(0,3,1,2) 
        #print(f'ref size: {reference_image.shape}')
        real_part_tar = target_image[:,0,:,:].unsqueeze(1)
        imag_part_tar = target_image[:,1,:,:].unsqueeze(1)
        mag_tar = torch.sqrt(real_part_tar**2 + imag_part_tar**2).to(device)
        """
        in_pad, wpad, hpad = model2.pad(mag_tar)
        input_norm,mean,std = model2.norm(in_pad.float())
        # Feature extract
        #print(mag_tar.shape)
        mag_tar = torch.cat((mag_tar,mag_tar,mag_tar),dim =1).to(device)
        
        features_target = vgg16_model(torch.cat((mag_tar,mag_tar,mag_tar),dim =1).to(device)).data
        """
        #print(f'Features target: {features_target.shape}')
        im_out = model(input,reference)#.squeeze(3)

        """
        # Plot the concatenated image
        real_part = image[0,0,:,:]
        imag_part = image[0,1,:,:]
        mag_image = torch.sqrt(real_part**2 + imag_part**2)
        real_part_ref = reference[0,:,:,0]
        imag_part_ref = reference[0,:,:,1]
        mag_ref = torch.sqrt(real_part_ref**2 + imag_part_ref**2)
        mag_ref = mag_ref.cpu().detach().numpy()
        print(f'Mag ref: {mag_ref.shape}')
        import matplotlib.pyplot as plt
        %matplotlib inline
        print(im_out.shape)
        print(mag_tar.shape)
        im_out = im_out.cpu().detach().numpy().squeeze(0)
        concat = np.concatenate((mag_ref,mag_image.cpu().detach().numpy(),np.abs(im_out),mag_tar.squeeze(0).cpu().detach().numpy()),axis=1)
        plt.figure(figsize=(12, 6))
        plt.imshow(concat, cmap='gray')
        plt.title('reference                         in                           out                       target   ')
        plt.axis('off')
        plt.show()
        
        l = torch(mag_tar)
        """
        #loss = criterion(im_out,features_target)
        # SSIM
        
        maxval = torch.max(torch.cat((im_out,mag_tar.permute(0,2,3,1)),dim=1))
        im_out = im_out.permute(0,3,1,2)

        #features_out = vgg16_model(torch.cat((im_out,im_out,im_out),dim =1))
        
        #print(features_out.shape)
        data_range = torch.tensor([maxval], device=device).view(1, 1, 1, 1).expand(im_out.size(0), im_out.size(1), im_out.size(2)-6, im_out.size(3)-6)
        #print(mag_tar.shape)
        #print(im_out.shape)
        #print(data_range.shape)
        # SSIM
        #loss = criterion(im_out, mag_tar.to(device), data_range.to(device))
        # MSE
        #loss = criterionMSE(im_out, mag_tar.to(device))       
        # pad:
        #im_out_pad = torch.cat((im_out,im_out,im_out),dim =1)/maxval
        #mag_tar_pad = torch.cat((mag_tar,mag_tar,mag_tar),dim =1)/maxval
        #loss = nn.MSELoss()(model_ft.features(im_out_pad), model_ft.features(mag_tar_pad))
        
        # SSIM + style - ready loss
        #print(f'ssim is : {+ criterion(im_out, mag_tar.to(device), data_range.to(device))}')
        loss = VGGloss(im_out,mag_tar.to(device))  +  criterion(im_out, mag_tar.to(device), data_range.to(device))  
        #loss = VGGloss(im_out,mag_tar.to(device))  +  50*criterionMSE(im_out, mag_tar.to(device))  
        ## Style +SSIM
        #loss = VGGloss(im_out,mag_tar.to(device))/15 + criterionMSE(im_out, mag_tar.to(device)) # criterion(im_out, mag_tar.to(device), data_range.to(device))   For tests2
        
        # SSIM loss for grant
        #loss = criterion(im_out, mag_tar.to(device), data_range.to(device))
        #loss = criterionMSE(im_out,mag_tar.to(device))
        # L1
        #loss = criterion(features_out, features_out)
        # MSE
        #loss = criterion(im_out,mag_tar.permute(0,2,3,1))
        
        running_loss = running_loss + loss.item()
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        if iter % 400 == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{params.epoch:3d}] '
                f'Iter = [{iter:4d}/{len(train_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g}'
            )
    #Saving the model
    exp_dir = "checkpoints_20dB_paper/"
    if epoch % 100 == 0:
        torch.save(
            {
                'epoch': epoch,
                'params': params,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'exp_dir': exp_dir
            },
            f=os.path.join(exp_dir, 'model_ViT_%d.pt'%(epoch))
    )
    running_loss = running_loss / len(train_loader)
    #scheduler.step(running_loss)
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    print(f'Epoch {epoch+1}, Learning rate: {current_lr}')

    #print(f'Epoch {epoch+1}, Learning rate: {scheduler.get_last_lr()[0]}')
    # Append epoch and average loss to plot lists
    epochs_plot.append(epoch)
    losses_plot.append(running_loss)

# Plotting the loss curve
plt.figure()
plt.plot(epochs_plot, losses_plot, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('SA unrolled with Reference L2 train Loss')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(exp_dir, 'loss_plot_plato_down.png'))  # Save plot as an image

# Save all_losses to a file for later comparison
losses_file = os.path.join(exp_dir, 'all_losses.txt')
with open(losses_file, 'w') as f:
    for loss in losses_plot:
        f.write(f'{loss}\n')

INFO:root:Epoch = [  0/101] Iter = [   0/  23] Loss = 0.8557 Avg Loss = 0.8557


Epoch 1, Learning rate: 1.672727272727273e-05


INFO:root:Epoch = [  1/101] Iter = [   0/  23] Loss = 0.6336 Avg Loss = 0.6336


Epoch 2, Learning rate: 2.5454545454545454e-05


INFO:root:Epoch = [  2/101] Iter = [   0/  23] Loss = 0.5635 Avg Loss = 0.5635


Epoch 3, Learning rate: 3.418181818181818e-05


INFO:root:Epoch = [  3/101] Iter = [   0/  23] Loss = 0.57 Avg Loss = 0.57


Epoch 4, Learning rate: 4.290909090909091e-05


INFO:root:Epoch = [  4/101] Iter = [   0/  23] Loss = 0.526 Avg Loss = 0.526


Epoch 5, Learning rate: 5.1636363636363634e-05


INFO:root:Epoch = [  5/101] Iter = [   0/  23] Loss = 0.5201 Avg Loss = 0.5201


Epoch 6, Learning rate: 6.036363636363636e-05


INFO:root:Epoch = [  6/101] Iter = [   0/  23] Loss = 0.4975 Avg Loss = 0.4975


Epoch 7, Learning rate: 6.90909090909091e-05


INFO:root:Epoch = [  7/101] Iter = [   0/  23] Loss = 0.5065 Avg Loss = 0.5065


Epoch 8, Learning rate: 7.781818181818183e-05


INFO:root:Epoch = [  8/101] Iter = [   0/  23] Loss = 0.5058 Avg Loss = 0.5058


Epoch 9, Learning rate: 8.654545454545456e-05


INFO:root:Epoch = [  9/101] Iter = [   0/  23] Loss = 0.5069 Avg Loss = 0.5069


Epoch 10, Learning rate: 9.527272727272728e-05


INFO:root:Epoch = [ 10/101] Iter = [   0/  23] Loss = 0.479 Avg Loss = 0.479


Epoch 11, Learning rate: 0.00010400000000000001


INFO:root:Epoch = [ 11/101] Iter = [   0/  23] Loss = 0.4836 Avg Loss = 0.4836


Epoch 12, Learning rate: 0.00011272727272727272


INFO:root:Epoch = [ 12/101] Iter = [   0/  23] Loss = 0.4657 Avg Loss = 0.4657


Epoch 13, Learning rate: 0.00012145454545454547


INFO:root:Epoch = [ 13/101] Iter = [   0/  23] Loss = 0.5015 Avg Loss = 0.5015


Epoch 14, Learning rate: 0.0001301818181818182


INFO:root:Epoch = [ 14/101] Iter = [   0/  23] Loss = 0.4886 Avg Loss = 0.4886


Epoch 15, Learning rate: 0.0001389090909090909


INFO:root:Epoch = [ 15/101] Iter = [   0/  23] Loss = 0.506 Avg Loss = 0.506


Epoch 16, Learning rate: 0.00014763636363636365


INFO:root:Epoch = [ 16/101] Iter = [   0/  23] Loss = 0.4775 Avg Loss = 0.4775


Epoch 17, Learning rate: 0.00015636363636363637


INFO:root:Epoch = [ 17/101] Iter = [   0/  23] Loss = 0.5149 Avg Loss = 0.5149


Epoch 18, Learning rate: 0.0001650909090909091


INFO:root:Epoch = [ 18/101] Iter = [   0/  23] Loss = 0.5062 Avg Loss = 0.5062


Epoch 19, Learning rate: 0.00017381818181818183


INFO:root:Epoch = [ 19/101] Iter = [   0/  23] Loss = 0.4991 Avg Loss = 0.4991


Epoch 20, Learning rate: 0.00018254545454545455


INFO:root:Epoch = [ 20/101] Iter = [   0/  23] Loss = 0.5047 Avg Loss = 0.5047


Epoch 21, Learning rate: 0.0001912727272727273


INFO:root:Epoch = [ 21/101] Iter = [   0/  23] Loss = 0.5076 Avg Loss = 0.5076


Epoch 22, Learning rate: 0.0002


INFO:root:Epoch = [ 22/101] Iter = [   0/  23] Loss = 0.4837 Avg Loss = 0.4837


Epoch 23, Learning rate: 0.0001999156785243742


INFO:root:Epoch = [ 23/101] Iter = [   0/  23] Loss = 0.4717 Avg Loss = 0.4717


Epoch 24, Learning rate: 0.00019983135704874837


INFO:root:Epoch = [ 24/101] Iter = [   0/  23] Loss = 0.481 Avg Loss = 0.481


Epoch 25, Learning rate: 0.00019974703557312253


INFO:root:Epoch = [ 25/101] Iter = [   0/  23] Loss = 0.4645 Avg Loss = 0.4645


Epoch 26, Learning rate: 0.0001996627140974967


INFO:root:Epoch = [ 26/101] Iter = [   0/  23] Loss = 0.4741 Avg Loss = 0.4741


Epoch 27, Learning rate: 0.0001995783926218709


INFO:root:Epoch = [ 27/101] Iter = [   0/  23] Loss = 0.4639 Avg Loss = 0.4639


Epoch 28, Learning rate: 0.00019949407114624508


INFO:root:Epoch = [ 28/101] Iter = [   0/  23] Loss = 0.4833 Avg Loss = 0.4833


Epoch 29, Learning rate: 0.00019940974967061926


INFO:root:Epoch = [ 29/101] Iter = [   0/  23] Loss = 0.4744 Avg Loss = 0.4744


Epoch 30, Learning rate: 0.0001993254281949934


INFO:root:Epoch = [ 30/101] Iter = [   0/  23] Loss = 0.4753 Avg Loss = 0.4753


Epoch 31, Learning rate: 0.0001992411067193676


INFO:root:Epoch = [ 31/101] Iter = [   0/  23] Loss = 0.4791 Avg Loss = 0.4791


Epoch 32, Learning rate: 0.00019915678524374178


INFO:root:Epoch = [ 32/101] Iter = [   0/  23] Loss = 0.4654 Avg Loss = 0.4654


Epoch 33, Learning rate: 0.00019907246376811596


INFO:root:Epoch = [ 33/101] Iter = [   0/  23] Loss = 0.4813 Avg Loss = 0.4813


Epoch 34, Learning rate: 0.00019898814229249014


INFO:root:Epoch = [ 34/101] Iter = [   0/  23] Loss = 0.4483 Avg Loss = 0.4483


Epoch 35, Learning rate: 0.0001989038208168643


INFO:root:Epoch = [ 35/101] Iter = [   0/  23] Loss = 0.4876 Avg Loss = 0.4876


Epoch 36, Learning rate: 0.00019881949934123848


INFO:root:Epoch = [ 36/101] Iter = [   0/  23] Loss = 0.4807 Avg Loss = 0.4807


Epoch 37, Learning rate: 0.00019873517786561266


INFO:root:Epoch = [ 37/101] Iter = [   0/  23] Loss = 0.428 Avg Loss = 0.428


Epoch 38, Learning rate: 0.00019865085638998684


INFO:root:Epoch = [ 38/101] Iter = [   0/  23] Loss = 0.4746 Avg Loss = 0.4746


Epoch 39, Learning rate: 0.000198566534914361


INFO:root:Epoch = [ 39/101] Iter = [   0/  23] Loss = 0.4797 Avg Loss = 0.4797


Epoch 40, Learning rate: 0.00019848221343873518


INFO:root:Epoch = [ 40/101] Iter = [   0/  23] Loss = 0.458 Avg Loss = 0.458


Epoch 41, Learning rate: 0.00019839789196310936


INFO:root:Epoch = [ 41/101] Iter = [   0/  23] Loss = 0.4622 Avg Loss = 0.4622


Epoch 42, Learning rate: 0.00019831357048748354


INFO:root:Epoch = [ 42/101] Iter = [   0/  23] Loss = 0.459 Avg Loss = 0.459


Epoch 43, Learning rate: 0.00019822924901185773


INFO:root:Epoch = [ 43/101] Iter = [   0/  23] Loss = 0.4651 Avg Loss = 0.4651


Epoch 44, Learning rate: 0.00019814492753623188


INFO:root:Epoch = [ 44/101] Iter = [   0/  23] Loss = 0.4622 Avg Loss = 0.4622


Epoch 45, Learning rate: 0.00019806060606060606


INFO:root:Epoch = [ 45/101] Iter = [   0/  23] Loss = 0.453 Avg Loss = 0.453


Epoch 46, Learning rate: 0.00019797628458498025


INFO:root:Epoch = [ 46/101] Iter = [   0/  23] Loss = 0.457 Avg Loss = 0.457


Epoch 47, Learning rate: 0.00019789196310935443


INFO:root:Epoch = [ 47/101] Iter = [   0/  23] Loss = 0.4675 Avg Loss = 0.4675


Epoch 48, Learning rate: 0.0001978076416337286


INFO:root:Epoch = [ 48/101] Iter = [   0/  23] Loss = 0.479 Avg Loss = 0.479


Epoch 49, Learning rate: 0.00019772332015810276


INFO:root:Epoch = [ 49/101] Iter = [   0/  23] Loss = 0.4572 Avg Loss = 0.4572


Epoch 50, Learning rate: 0.00019763899868247695


INFO:root:Epoch = [ 50/101] Iter = [   0/  23] Loss = 0.4756 Avg Loss = 0.4756


Epoch 51, Learning rate: 0.00019755467720685113


INFO:root:Epoch = [ 51/101] Iter = [   0/  23] Loss = 0.4451 Avg Loss = 0.4451


Epoch 52, Learning rate: 0.0001974703557312253


INFO:root:Epoch = [ 52/101] Iter = [   0/  23] Loss = 0.4358 Avg Loss = 0.4358


Epoch 53, Learning rate: 0.0001973860342555995


INFO:root:Epoch = [ 53/101] Iter = [   0/  23] Loss = 0.4662 Avg Loss = 0.4662


Epoch 54, Learning rate: 0.00019730171277997365


INFO:root:Epoch = [ 54/101] Iter = [   0/  23] Loss = 0.4513 Avg Loss = 0.4513


Epoch 55, Learning rate: 0.00019721739130434783


INFO:root:Epoch = [ 55/101] Iter = [   0/  23] Loss = 0.4476 Avg Loss = 0.4476


Epoch 56, Learning rate: 0.000197133069828722


INFO:root:Epoch = [ 56/101] Iter = [   0/  23] Loss = 0.4396 Avg Loss = 0.4396


Epoch 57, Learning rate: 0.0001970487483530962


INFO:root:Epoch = [ 57/101] Iter = [   0/  23] Loss = 0.4514 Avg Loss = 0.4514


Epoch 58, Learning rate: 0.00019696442687747038


INFO:root:Epoch = [ 58/101] Iter = [   0/  23] Loss = 0.4646 Avg Loss = 0.4646


Epoch 59, Learning rate: 0.00019688010540184453


INFO:root:Epoch = [ 59/101] Iter = [   0/  23] Loss = 0.4379 Avg Loss = 0.4379


Epoch 60, Learning rate: 0.00019679578392621871


INFO:root:Epoch = [ 60/101] Iter = [   0/  23] Loss = 0.4496 Avg Loss = 0.4496


Epoch 61, Learning rate: 0.0001967114624505929


INFO:root:Epoch = [ 61/101] Iter = [   0/  23] Loss = 0.4821 Avg Loss = 0.4821


Epoch 62, Learning rate: 0.00019662714097496708


INFO:root:Epoch = [ 62/101] Iter = [   0/  23] Loss = 0.4499 Avg Loss = 0.4499


Epoch 63, Learning rate: 0.00019654281949934126


INFO:root:Epoch = [ 63/101] Iter = [   0/  23] Loss = 0.4595 Avg Loss = 0.4595


Epoch 64, Learning rate: 0.00019645849802371542


INFO:root:Epoch = [ 64/101] Iter = [   0/  23] Loss = 0.4693 Avg Loss = 0.4693


Epoch 65, Learning rate: 0.0001963741765480896


INFO:root:Epoch = [ 65/101] Iter = [   0/  23] Loss = 0.4521 Avg Loss = 0.4521


Epoch 66, Learning rate: 0.00019628985507246378


INFO:root:Epoch = [ 66/101] Iter = [   0/  23] Loss = 0.4671 Avg Loss = 0.4671


Epoch 67, Learning rate: 0.00019620553359683796


INFO:root:Epoch = [ 67/101] Iter = [   0/  23] Loss = 0.4438 Avg Loss = 0.4438


Epoch 68, Learning rate: 0.00019612121212121214


INFO:root:Epoch = [ 68/101] Iter = [   0/  23] Loss = 0.4487 Avg Loss = 0.4487


Epoch 69, Learning rate: 0.0001960368906455863


INFO:root:Epoch = [ 69/101] Iter = [   0/  23] Loss = 0.45 Avg Loss = 0.45


Epoch 70, Learning rate: 0.00019595256916996048


INFO:root:Epoch = [ 70/101] Iter = [   0/  23] Loss = 0.4278 Avg Loss = 0.4278


Epoch 71, Learning rate: 0.00019586824769433466


INFO:root:Epoch = [ 71/101] Iter = [   0/  23] Loss = 0.467 Avg Loss = 0.467


Epoch 72, Learning rate: 0.00019578392621870885


INFO:root:Epoch = [ 72/101] Iter = [   0/  23] Loss = 0.461 Avg Loss = 0.461


Epoch 73, Learning rate: 0.000195699604743083


INFO:root:Epoch = [ 73/101] Iter = [   0/  23] Loss = 0.4308 Avg Loss = 0.4308


Epoch 74, Learning rate: 0.00019561528326745718


INFO:root:Epoch = [ 74/101] Iter = [   0/  23] Loss = 0.4737 Avg Loss = 0.4737


Epoch 75, Learning rate: 0.00019553096179183136


INFO:root:Epoch = [ 75/101] Iter = [   0/  23] Loss = 0.4418 Avg Loss = 0.4418


Epoch 76, Learning rate: 0.00019544664031620555


INFO:root:Epoch = [ 76/101] Iter = [   0/  23] Loss = 0.4473 Avg Loss = 0.4473


Epoch 77, Learning rate: 0.00019536231884057973


INFO:root:Epoch = [ 77/101] Iter = [   0/  23] Loss = 0.4464 Avg Loss = 0.4464


Epoch 78, Learning rate: 0.00019527799736495388


INFO:root:Epoch = [ 78/101] Iter = [   0/  23] Loss = 0.4683 Avg Loss = 0.4683


Epoch 79, Learning rate: 0.00019519367588932807


INFO:root:Epoch = [ 79/101] Iter = [   0/  23] Loss = 0.4363 Avg Loss = 0.4363


Epoch 80, Learning rate: 0.00019510935441370225


INFO:root:Epoch = [ 80/101] Iter = [   0/  23] Loss = 0.4769 Avg Loss = 0.4769


Epoch 81, Learning rate: 0.00019502503293807643


INFO:root:Epoch = [ 81/101] Iter = [   0/  23] Loss = 0.4866 Avg Loss = 0.4866


Epoch 82, Learning rate: 0.0001949407114624506


INFO:root:Epoch = [ 82/101] Iter = [   0/  23] Loss = 0.4348 Avg Loss = 0.4348


Epoch 83, Learning rate: 0.00019485638998682477


INFO:root:Epoch = [ 83/101] Iter = [   0/  23] Loss = 0.4541 Avg Loss = 0.4541


Epoch 84, Learning rate: 0.00019477206851119895


INFO:root:Epoch = [ 84/101] Iter = [   0/  23] Loss = 0.4425 Avg Loss = 0.4425


Epoch 85, Learning rate: 0.00019468774703557313


INFO:root:Epoch = [ 85/101] Iter = [   0/  23] Loss = 0.4568 Avg Loss = 0.4568


Epoch 86, Learning rate: 0.00019460342555994731


INFO:root:Epoch = [ 86/101] Iter = [   0/  23] Loss = 0.4791 Avg Loss = 0.4791


Epoch 87, Learning rate: 0.0001945191040843215


INFO:root:Epoch = [ 87/101] Iter = [   0/  23] Loss = 0.4551 Avg Loss = 0.4551


Epoch 88, Learning rate: 0.00019443478260869565


INFO:root:Epoch = [ 88/101] Iter = [   0/  23] Loss = 0.4412 Avg Loss = 0.4412


Epoch 89, Learning rate: 0.00019435046113306983


INFO:root:Epoch = [ 89/101] Iter = [   0/  23] Loss = 0.4471 Avg Loss = 0.4471


Epoch 90, Learning rate: 0.00019426613965744402


INFO:root:Epoch = [ 90/101] Iter = [   0/  23] Loss = 0.4502 Avg Loss = 0.4502


Epoch 91, Learning rate: 0.0001941818181818182


INFO:root:Epoch = [ 91/101] Iter = [   0/  23] Loss = 0.452 Avg Loss = 0.452


Epoch 92, Learning rate: 0.00019409749670619238


INFO:root:Epoch = [ 92/101] Iter = [   0/  23] Loss = 0.4263 Avg Loss = 0.4263


Epoch 93, Learning rate: 0.00019401317523056654


INFO:root:Epoch = [ 93/101] Iter = [   0/  23] Loss = 0.4581 Avg Loss = 0.4581


Epoch 94, Learning rate: 0.00019392885375494072


INFO:root:Epoch = [ 94/101] Iter = [   0/  23] Loss = 0.4649 Avg Loss = 0.4649


Epoch 95, Learning rate: 0.0001938445322793149


INFO:root:Epoch = [ 95/101] Iter = [   0/  23] Loss = 0.4525 Avg Loss = 0.4525


Epoch 96, Learning rate: 0.00019376021080368908


INFO:root:Epoch = [ 96/101] Iter = [   0/  23] Loss = 0.435 Avg Loss = 0.435


Epoch 97, Learning rate: 0.00019367588932806326


INFO:root:Epoch = [ 97/101] Iter = [   0/  23] Loss = 0.4454 Avg Loss = 0.4454


Epoch 98, Learning rate: 0.00019359156785243742


INFO:root:Epoch = [ 98/101] Iter = [   0/  23] Loss = 0.4614 Avg Loss = 0.4614


Epoch 99, Learning rate: 0.0001935072463768116


INFO:root:Epoch = [ 99/101] Iter = [   0/  23] Loss = 0.4377 Avg Loss = 0.4377


Epoch 100, Learning rate: 0.00019342292490118578


INFO:root:Epoch = [100/101] Iter = [   0/  23] Loss = 0.4453 Avg Loss = 0.4453


Epoch 101, Learning rate: 0.00019333860342555997


<IPython.core.display.Javascript object>