In [1]:
%matplotlib notebook
import os, sys
import logging
import random
import h5py
import shutil
import time
import argparse
import numpy as np
import sigpy.plot as pl
import torch
import sigpy as sp
import torchvision
from torch import optim
from tensorboardX import SummaryWriter
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib
# import custom libraries
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx
from utils.resnet2p1d import generate_model
from utils.flare_utils import roll
# import custom classes
from utils.datasets import SliceData
from subsample_fastmri import MaskFunc
from MoDL_single import UnrolledModel
import argparse
from models.SAmodel import MyNetwork
from models.Unrolled import Unrolled
from models.UnrolledRef import UnrolledRef
from models.UnrolledTransformer import UnrolledTrans
import matplotlib.pyplot as plt
from ImageFusionBlock import ImageFusionBlock
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
from models.UnrolledFusion import UnrolledFusion
from fastmri.data import transforms, subsample
%load_ext autoreload
%autoreload 0

In [2]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [3]:
class DataTransform:
    """
    Data Transformer for training unrolled reconstruction models.
    """

    def __init__(self, mask_func, args, use_seed=False):
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.rng = np.random.RandomState()
    
    def get_mask_func(self, factor):
        center_fractions = 0.08 * 4/factor # EquiSpacedMaskFuncRandomMaskFunc
        mask_func = subsample.EquiSpacedMaskFunc(
        center_fractions=[center_fractions],
        accelerations=[factor], 
        )
        return mask_func
    
    def __call__(self, kspace, target, reference, reference_kspace,slice):
        im_lowres = abs(sp.ifft(sp.resize(sp.resize(kspace,(256,24)),(256,160))))
        magnitude_vals = im_lowres.reshape(-1)
        k = int(round(0.05 * magnitude_vals.shape[0]))
        scale = magnitude_vals[magnitude_vals.argsort()[::-1][k]]
        kspace = kspace/scale
        target = target/scale
        # Convert everything from numpy arrays to tensors
        kspace_torch = cplx.to_tensor(kspace).float()   
        target_torch = cplx.to_tensor(target).float()  
        target_torch = T.ifft2(T.kspace_cut(T.fft2(target_torch),0.67,0.67)) 
        # Use poisson mask instead
        
        mask2 = sp.mri.poisson((172,108), 3, calib=(18, 14), dtype=float, crop_corner=False, return_density=True, seed=0, max_attempts=6, tol=0.01)
        mask2[86-10:86+10,54-8:54+8] = 1
        mask_torch = torch.stack([torch.tensor(mask2).float(),torch.tensor(mask2).float()],dim=2)
        #kspace_torch = T.kspace_cut(mask_torch,0.5)
        kspace_torch = T.awgn_torch(kspace_torch,10,L=1)
        kspace_torch = T.kspace_cut(kspace_torch,0.67,0.67)

        #print(f'mask: {mask_torch.shape}')
        #print(f'kspace: {kspace_torch.shape}')
        
        #For poisson
        #kspace_torch = kspace_torch * mask_torch
        ## Masking
        mask_func = self.get_mask_func(3)
        kspace_torch = transforms.apply_mask(kspace_torch, mask_func)[0]
        # kspace_torch = kspace_torch*mask_torch # For poisson
        
        """
        mask = np.abs(cplx.to_numpy(kspace_torch))!=0
        mask_torch = torch.stack([torch.tensor(mask).float(),torch.tensor(mask).float()],dim=2)
        """
        ### Reference addition ###
        im_lowres_ref = abs(sp.ifft(sp.resize(sp.resize(reference_kspace,(256,24)),(256,160))))
        magnitude_vals_ref = im_lowres_ref.reshape(-1)
        k_ref = int(round(0.05 * magnitude_vals_ref.shape[0]))
        scale_ref = magnitude_vals_ref[magnitude_vals_ref.argsort()[::-1][k_ref]]
        reference = reference / scale_ref
        reference_torch = cplx.to_tensor(reference).float()
        reference_torch_kspace = T.fft2(reference_torch)
        reference_torch_kspace = T.kspace_cut(reference_torch_kspace,0.67,0.67)
        reference_torch = T.ifft2(reference_torch_kspace)
        

        return kspace_torch,target_torch,mask_torch, reference_torch 

In [4]:
def create_datasets(args):
    # Generate k-t undersampling masks
    train_mask = MaskFunc([0.08],[4])
    train_data = SliceData(
        root=str(args.data_path),
        transform=DataTransform(train_mask, args),
        sample_rate=1
    )
    return train_data
def create_data_loaders(args):
    train_data = create_datasets(args)
#     print(train_data[0])

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader
def build_optim(args, params):
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    return optimizer

In [5]:
#Hyper parameters
params = Namespace()
#params.data_path = "./registered_data/patient23b/"
params.data_path = "./registered_data/"
params.batch_size = 16
params.num_grad_steps = 4 #4
params.num_cg_steps = 8 #8
params.share_weights = True
params.modl_lamda = 0.05
params.lr = 0.0001
#params.lr = 0.0001
params.weight_decay = 0
params.lr_step_size = 15
params.lr_gamma = 0.3
params.epoch = 101
params.reference_mode = 0
params.reference_lambda = 0.1

In [6]:
train_loader = create_data_loaders(params)

In [7]:
#single_MoDL = UnrolledFusion(params).to(device)
#single_MoDL.FusionModel.requires_grad_(False)

single_MoDL = UnrolledModel(params).to(device)
#single_MoDL = MyNetwork(2,2).to(device)
#single_MoDL = Unrolled(params).to(device)
#single_MoDL = UnrolledRef(params).to(device)
#single_MoDL = UnrolledTrans(params).to(device)
optimizer = build_optim(params, single_MoDL.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, params.lr_step_size, params.lr_gamma)

## Hybrid loss
from FSloss_wrap import VGGPerceptualLoss
#VGGloss = VGGLoss().to(device)
VGGloss = VGGPerceptualLoss().to(device)
"""
optimizer = optim.Adam(single_MoDL.parameters(), lr=0.0)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer, 
    max_lr=0.0002,
    steps_per_epoch=len(train_loader),
    epochs=100,
    pct_start=0.01,
    anneal_strategy='linear',
    cycle_momentum=False,
    base_momentum=0., 
    max_momentum=0.,
    div_factor = 25.,
    final_div_factor=1.,
)

"""

shared weights




"\noptimizer = optim.Adam(single_MoDL.parameters(), lr=0.0)\nscheduler = optim.lr_scheduler.OneCycleLR(\n    optimizer=optimizer, \n    max_lr=0.0002,\n    steps_per_epoch=len(train_loader),\n    epochs=100,\n    pct_start=0.01,\n    anneal_strategy='linear',\n    cycle_momentum=False,\n    base_momentum=0., \n    max_momentum=0.,\n    div_factor = 25.,\n    final_div_factor=1.,\n)\n\n"

In [8]:
### Load for fine-tunning
#checkpoint_file = "./L2_checkpoints_poisson_x4_SAunrolledOF/model_20.pt"
#checkpoint = torch.load(checkpoint_file,map_location=device)
#params = checkpoint["params"]
#single_MoDL.load_state_dict(checkpoint['model'])
import warnings

# Suppress specific deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*Function sporco.util.tikhonov_filter is deprecated.*")

# Suppress specific user warnings from torchvision.models
warnings.filterwarnings("ignore", message=".*Using 'weights' as positional parameter.*")
warnings.filterwarnings("ignore", message=".*Arguments other than a weight enum or `None` for 'weights' are deprecated.*")
warnings.filterwarnings("ignore", category=DeprecationWarning, 
                        message=".*Function sporco.util.tikhonov_filter is deprecated; please use function sporco.signal.tikhonov_filter instead.*")

epochs_plot = []
losses_plot = []

from fastmri.losses import SSIMLoss
criterion = SSIMLoss().to(device)
criterionMSE = nn.MSELoss()
for epoch in range(params.epoch):
    single_MoDL.train()
    avg_loss = 0.
    running_loss = 0.0
    for iter, data in enumerate(train_loader):
        input,target,mask,reference = data
        input = input.to(device)
        target = target.to(device)
        mask = mask.to(device)
        reference = reference.to(device)

        im_out = single_MoDL(input.float(),reference_image=reference,mask=mask)

        # SSIM calc
        
        target_image = target.permute(0,3,1,2) 
        real_part_tar = target_image[:,0,:,:].unsqueeze(1)
        imag_part_tar = target_image[:,1,:,:].unsqueeze(1)
        mag_tar = torch.sqrt(real_part_tar**2 + imag_part_tar**2).to(device)

        real_part_out = im_out[:,:,:,0].unsqueeze(1)
        imag_part_out = im_out[:,:,:,1].unsqueeze(1)
        im_out_abs = torch.sqrt(real_part_out**2 + imag_part_out**2).to(device)

        maxval = torch.max(torch.cat((im_out_abs,mag_tar),dim=1))
        data_range = torch.tensor([maxval], device=device).view(1, 1, 1, 1).expand(im_out_abs.size(0), im_out_abs.size(1), im_out_abs.size(2)-6, im_out_abs.size(3)-6)
        ## SSIM loss
        loss = criterion(im_out_abs, mag_tar.to(device), data_range.to(device)) 
        ## MSE loss
        ## abs
        #loss = criterionMSE(im_out_abs,mag_tar.to(device)) #MSE
        ## complex
        #loss = criterionMSE(im_out,target.to(device))
        ## Hybrid loss
        #loss = VGGloss(im_out_abs,mag_tar.to(device))  +  criterion(im_out_abs, mag_tar.to(device), data_range.to(device))  

        running_loss = running_loss + loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        if iter % 125 == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{params.epoch:3d}] '
                f'Iter = [{iter:4d}/{len(train_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g}'
            )
    #Saving the model
    exp_dir = "checkpoints_10dB/"
    if epoch % 100 == 0:
        torch.save(
            {
                'epoch': epoch,
                'params': params,
                'model': single_MoDL.state_dict(),
                'optimizer': optimizer.state_dict(),
                'exp_dir': exp_dir
            },
            f=os.path.join(exp_dir, 'model_MoDLSSIM_%d.pt'%(epoch))
    )
    running_loss = running_loss / len(train_loader)
    #scheduler.step(running_loss)
    scheduler.step()
    # Append epoch and average loss to plot lists
    epochs_plot.append(epoch)
    losses_plot.append(running_loss)

# Plotting the loss curve
plt.figure()
plt.plot(epochs_plot, losses_plot, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('SA unrolled with Reference L2 train Loss')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(exp_dir, 'loss_plot_plato_down.png'))  # Save plot as an image

# Save all_losses to a file for later comparison
losses_file = os.path.join(exp_dir, 'all_losses.txt')
with open(losses_file, 'w') as f:
    for loss in losses_plot:
        f.write(f'{loss}\n')

INFO:root:Epoch = [  0/101] Iter = [   0/  34] Loss = 0.6332 Avg Loss = 0.6332
INFO:root:Epoch = [  1/101] Iter = [   0/  34] Loss = 0.5797 Avg Loss = 0.5797
INFO:root:Epoch = [  2/101] Iter = [   0/  34] Loss = 0.5402 Avg Loss = 0.5402
INFO:root:Epoch = [  3/101] Iter = [   0/  34] Loss = 0.5263 Avg Loss = 0.5263
INFO:root:Epoch = [  4/101] Iter = [   0/  34] Loss = 0.4902 Avg Loss = 0.4902
INFO:root:Epoch = [  5/101] Iter = [   0/  34] Loss = 0.4985 Avg Loss = 0.4985
INFO:root:Epoch = [  6/101] Iter = [   0/  34] Loss = 0.5161 Avg Loss = 0.5161
INFO:root:Epoch = [  7/101] Iter = [   0/  34] Loss = 0.5077 Avg Loss = 0.5077
INFO:root:Epoch = [  8/101] Iter = [   0/  34] Loss = 0.4635 Avg Loss = 0.4635
INFO:root:Epoch = [  9/101] Iter = [   0/  34] Loss = 0.486 Avg Loss = 0.486
INFO:root:Epoch = [ 10/101] Iter = [   0/  34] Loss = 0.431 Avg Loss = 0.431
INFO:root:Epoch = [ 11/101] Iter = [   0/  34] Loss = 0.4863 Avg Loss = 0.4863
INFO:root:Epoch = [ 12/101] Iter = [   0/  34] Loss = 0.

KeyboardInterrupt: 