In [1]:
%matplotlib notebook
import os, sys
import logging
import random
import h5py
import shutil
import time
import argparse
import numpy as np
import sigpy.plot as pl
import torch
import sigpy as sp
import torchvision
from torch import optim
from tensorboardX import SummaryWriter
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib
# import custom libraries
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx
from utils.resnet2p1d import generate_model
from utils.flare_utils import roll
# import custom classes
from utils.datasets import SliceData
from subsample_fastmri import MaskFunc
from MoDL_single import UnrolledModel
import argparse
from models.SAmodel import MyNetwork
from models.Unrolled import Unrolled
from models.UnrolledRef import UnrolledRef
from models.UnrolledTransformer import UnrolledTrans
import matplotlib.pyplot as plt
from ImageFusionBlock import ImageFusionBlock
# Root logger at INFO level: records sent through the module-level `logging`
# functions are printed with the "INFO:root:" prefix seen in the cell output.
logging.basicConfig(level=logging.INFO)
# Named logger for this notebook module (the visible training loop logs via `logging.info`).
logger = logging.getLogger(__name__)
# Prefer the first CUDA GPU; fall back to CPU when CUDA is unavailable.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
# NOTE(review): mode 0 disables automatic reloading, so edits to the project
# modules imported in this cell need a kernel restart (or `%autoreload 2`).
%autoreload 0
from ImageFusion_Dualbranch_Fusion.densefuse_net import DenseFuseNet
from ImageFusion_Dualbranch_Fusion.channel_fusion import channel_f as channel_fusion
import itertools
from RCAN import CombinedNetwork
from models.FusionNet import FusionNet
from recon_net_wrap import ViTfuser
from UnrolledViT import UnrolledViT

from fastmri.data import transforms, subsample

In [2]:
class Namespace:
    """Tiny attribute container used in place of an argparse result.

    Every keyword argument passed to the constructor becomes an instance
    attribute, in the order given.
    """

    def __init__(self, **kwargs):
        attributes = vars(self)
        attributes.update(kwargs)

In [3]:
class DataTransform:
    """
    Data Transformer for training unrolled reconstruction models.

    Each call normalizes the slice (k-space, target) and the reference image,
    crops k-space with ``T.kspace_cut``, adds noise, retrospectively
    undersamples with a fastMRI equispaced mask and returns torch tensors.
    """

    def __init__(self, mask_func, args, use_seed=False, acceleration=3):
        """
        Args:
            mask_func: Stored as ``self.mask_func`` but NOT used by ``__call__``;
                the mask is built with ``get_mask_func(self.acceleration)``.
            args: Accepted for interface compatibility; currently unused.
            use_seed: Stored only; masks are not seeded per slice.
            acceleration: Undersampling factor used for the mask. The default
                of 3 is the value that used to be hard-coded in ``__call__``.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.acceleration = acceleration
        self.rng = np.random.RandomState()

    def get_mask_func(self, factor):
        """Build a fastMRI equispaced mask function for acceleration ``factor``.

        The fully-sampled center fraction scales as ``0.08 * 4 / factor``
        (0.08 at 4x).
        """
        center_fractions = 0.08 * 4 / factor
        return subsample.EquiSpacedMaskFunc(
            center_fractions=[center_fractions],
            accelerations=[factor],
        )

    @staticmethod
    def _kth_largest_fraction(values, fraction=0.05):
        """Return the element ranked ``round(fraction * n)`` (0-based) in
        descending order among the flattened ``values``.

        Raises:
            ValueError: if that rank does not exist (e.g. empty input).
        """
        flat = np.asarray(values).reshape(-1)
        n = flat.shape[0]
        k = int(round(fraction * n))
        if not 0 <= k < n:
            raise ValueError(f'rank {k} out of range for {n} values')
        # Descending rank k == ascending rank n-1-k; partition avoids a full sort.
        return np.partition(flat, n - 1 - k)[n - 1 - k]

    @classmethod
    def _robust_scale(cls, kspace):
        """Normalization constant for one slice.

        Resizes k-space to (256, 24) and back to (256, 160) with sigpy,
        inverse-FFTs it and takes the top-5% magnitude of that low-resolution
        image. Assumes slices are 256x160 -- TODO confirm for other datasets.
        """
        im_lowres = abs(sp.ifft(sp.resize(sp.resize(kspace, (256, 24)), (256, 160))))
        return cls._kth_largest_fraction(im_lowres, 0.05)

    def __call__(self, kspace, target, reference_kspace, reference, slice):
        """Build one training sample.

        Args:
            kspace: Complex k-space array of the slice.
            target: Target image of the slice (scaled like ``kspace``).
            reference_kspace: K-space of the reference; only used for its scale.
            reference: Reference image.
            slice: Slice index; unused.

        Returns:
            ``(kspace_torch, target_torch, mask_torch, reference_torch)`` where
            ``kspace_torch`` is the noisy, cropped, undersampled k-space and
            ``mask_torch`` is 1 where ``kspace_torch`` is non-zero, stacked
            twice along dim 2.
        """
        # K-space and target share one scale so they stay consistent.
        scale = self._robust_scale(kspace)
        kspace = kspace / scale
        target = target / scale

        # Convert everything from numpy arrays to tensors
        kspace_torch = cplx.to_tensor(kspace).float()
        target_torch = cplx.to_tensor(target).float()
        target_torch = T.ifft2(T.kspace_cut(T.fft2(target_torch), 0.67, 0.67))

        # (A Poisson-disc mask via sp.mri.poisson was tried instead of this.)
        kspace_torch = T.awgn_torch(kspace_torch, 10, L=1)
        # A fresh mask function per call, as before; presumably this avoids
        # DataLoader workers sharing one mask RNG state -- verify before hoisting.
        mask_func = self.get_mask_func(self.acceleration)
        kspace_torch = T.kspace_cut(kspace_torch, 0.67, 0.67)
        kspace_torch = transforms.apply_mask(kspace_torch, mask_func)[0]

        # Sampling mask recovered from the non-zero k-space entries.
        mask = np.abs(cplx.to_numpy(kspace_torch)) != 0
        mask_channel = torch.tensor(mask).float()
        mask_torch = torch.stack([mask_channel, mask_channel], dim=2)

        # The reference uses its own scale, then the same k-space crop as the target.
        reference = reference / self._robust_scale(reference_kspace)
        reference_torch = cplx.to_tensor(reference).float()
        reference_torch = T.ifft2(T.kspace_cut(T.fft2(reference_torch), 0.67, 0.67))

        return kspace_torch, target_torch, mask_torch, reference_torch

In [4]:
def create_datasets(args):
    """Build the training ``SliceData`` dataset rooted at ``args.data_path``.

    A ``MaskFunc([0.08], [4])`` is handed to ``DataTransform`` together with
    ``args``; ``sample_rate=1`` keeps every slice.
    """
    mask = MaskFunc([0.08], [4])
    transform = DataTransform(mask, args)
    return SliceData(root=str(args.data_path), transform=transform, sample_rate=1)
def create_data_loaders(args):
    """Wrap the training dataset in a shuffling ``DataLoader``.

    Reads ``args.batch_size``. ``args.num_workers`` and ``args.pin_memory``
    are optional and default to 8 and True (the previously hard-coded values).
    """
    train_data = create_datasets(args)
    return DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=getattr(args, 'num_workers', 8),
        pin_memory=getattr(args, 'pin_memory', True),
    )
def build_optim(args, params):
    """Return an Adam optimizer over ``params`` configured from ``args.lr`` and ``args.weight_decay``."""
    adam_kwargs = {'lr': args.lr, 'weight_decay': args.weight_decay}
    return torch.optim.Adam(params, **adam_kwargs)

In [5]:
# Hyper parameters (values tried in earlier runs are noted inline)
params = Namespace(
    data_path="./registered_data/",  # also used: "./registered_data/patient23b/"
    batch_size=4,
    num_grad_steps=1,                # 4 in an earlier run
    num_cg_steps=8,
    share_weights=True,
    modl_lamda=0.05,
    lr=0.001,                        # also tried 0.0005 and 0.0001
    weight_decay=0,
    lr_step_size=5,
    lr_gamma=0.3,
    epoch=31,
    reference_mode=1,
    reference_lambda=0.1,
)


In [6]:
train_loader = create_data_loaders(args=params)  # shuffled training DataLoader



In [7]:
from torchvision import models
from FSloss_wrap import VGGLoss, ResNet18Backbone, FeatureEmbedding, contrastive_loss, VGGPerceptualLoss

# Perceptual loss used in the training objective.
VGGloss = VGGPerceptualLoss().to(device)
# NOTE: despite its name, `UFLoss` is a VGGLoss instance (earlier experiments
# used ResNet18Backbone or a truncated torchvision VGG16 here).
UFLoss = VGGLoss().to(device)

def extract_patches(images, patch_size=(10, 10), stride=(10, 10)):
    """Cut a batch of images into 2-D patches.

    Args:
        images: Tensor of shape (batch, channels, H, W), e.g. (B, 1, 180, 110).
        patch_size: (ph, pw) patch height and width.
        stride: (sh, sw) step between patch origins; equal to ``patch_size``
            gives non-overlapping patches.

    Returns:
        Tensor of shape (batch, num_patches, channels, ph, pw), patches ordered
        row-major over the patch grid. Rows/columns that do not fill a whole
        patch at the end are dropped (``Tensor.unfold`` semantics).
    """
    batch, channels = images.size(0), images.size(1)
    # (B, C, nH, nW, ph, pw) after unfolding height then width.
    patches = images.unfold(2, patch_size[0], stride[0]).unfold(3, patch_size[1], stride[1])
    # Put channels next to the pixel dims: (B, nH, nW, C, ph, pw).
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    # Use the real channel count; a hard-coded 1 would fold extra channels
    # into the patch axis when C > 1.
    return patches.reshape(batch, -1, channels, patch_size[0], patch_size[1])


# Example usage: 10x10 patches taken every 10 pixels, i.e. non-overlapping.
patch_size = (10, 10)
stride = patch_size
def feature_space_loss(features1, features2):
    """Mean squared error between two feature tensors, averaged over all elements."""
    return F.mse_loss(input=features1, target=features2, reduction='mean')
def pad_image(images):
    """Zero-pad the last two dims: 4 rows top and bottom, 1 column left and right.

    E.g. (B, 1, 172, 108) -> (B, 1, 180, 110).
    """
    # F.pad lists widths starting from the last dim: (left, right, top, bottom).
    widths = (1, 1, 4, 4)
    return F.pad(images, widths, mode='constant', value=0)

# (Disabled experiment: ResNet18 feature embedding plus an L2-normalized random
#  memory bank of 16x128 vectors for a contrastive patch loss.)

from vision_transformer import VisionTransformer

# NOTE(review): `net` is not referenced by any live code visible in this
# notebook; it still builds (and presumably randomly initializes) a full ViT.
net = VisionTransformer(
    avrg_img_size=320,
    patch_size=(10, 10),
    in_chans=1,
    embed_dim=64,
    depth=10,
    num_heads=16,
)

from recon_net import ReconNet

model = UnrolledViT(params).to(device)

# Disabled alternative: a pretrained ViT reconstruction network.
# model2 = ReconNet(net).to(device)  # .requires_grad_(False)
# cp = torch.load('./lsdir-2x+hq50k_vit_epoch_60.pt', map_location=device)
# model2.load_state_dict(cp['model_state_dict'])

# Disabled alternative: freeze everything except the similarity nets' scalars.
# model.requires_grad_(False)
# for sim_net in model.similaritynets:
#     sim_net.param1.requires_grad_(True)
#     sim_net.param2.requires_grad_(True)

# Disabled alternative: Adam(lr=0.0) + OneCycleLR(max_lr=1e-4,
#   steps_per_epoch=len(train_loader), epochs=params.epoch, pct_start=0.01,
#   anneal_strategy='linear', cycle_momentum=False, div_factor=25.,
#   final_div_factor=1.).

# Adam from params.lr / params.weight_decay; LR multiplied by lr_gamma every lr_step_size epochs.
optimizer = build_optim(params, model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=params.lr_step_size,
    gamma=params.lr_gamma,
)




shared weights


In [8]:
### Load for fine-tuning (disabled). To resume from a checkpoint:
# checkpoint_file = "./L2_checkpoints_poisson_x2_FusionNet/model_30.pt"
# checkpoint = torch.load(checkpoint_file, map_location=device)
# model.load_state_dict(checkpoint['model'])
# model.recon_net.requires_grad_(True)
from fastmri.losses import SSIMLoss
criterion = SSIMLoss().to(device)
criterionMSE = nn.MSELoss()  # not used by the objective below; kept for experiments
#criterion = nn.L1Loss()

# Objective = SSIM loss + VGG perceptual loss / VGG_LOSS_DIVISOR (empirical weight).
VGG_LOSS_DIVISOR = 220
# data_range is sized (H - 6) x (W - 6); presumably the valid output of a 7x7
# SSIM window without padding -- TODO confirm against fastmri SSIMLoss.
SSIM_BORDER = 6
CHECKPOINT_EVERY = 5  # epochs

exp_dir = "L2_checkpoints_poisson_x2_ViT_LR_tests2/"
# torch.save does not create parent directories; create it before training.
os.makedirs(exp_dir, exist_ok=True)

epochs_plot = []
losses_plot = []


def _ssim_vgg_loss(output, target_mag):
    """SSIM loss plus scaled VGG perceptual loss for one batch.

    Args:
        output: Model output, assumed (B, H, W, 1) -- TODO confirm against UnrolledViT.
        target_mag: Target magnitude image, (B, 1, H, W).

    Returns:
        Scalar loss tensor.
    """
    # One batch-wide intensity range for prediction and target; detached, so
    # no gradient flows through data_range.
    maxval = torch.maximum(output.detach().max(), target_mag.max())
    output = output.permute(0, 3, 1, 2)
    data_range = maxval.reshape(1, 1, 1, 1).expand(
        output.size(0),
        output.size(1),
        output.size(2) - SSIM_BORDER,
        output.size(3) - SSIM_BORDER,
    )
    return VGGloss(output, target_mag) / VGG_LOSS_DIVISOR + criterion(output, target_mag, data_range)


try:
    for epoch in range(params.epoch):
        model.train()
        avg_loss = 0.
        running_loss = 0.0
        # `step` / `kspace_in` rather than `iter` / `input`: these become
        # notebook globals and must not shadow the builtins for later cells.
        for step, (kspace_in, target, mask, reference) in enumerate(train_loader):
            kspace_in = kspace_in.to(device).float()
            target = target.to(device).float()
            reference = reference.to(device).float()
            # `mask` is not consumed by this model.

            # Target magnitude (B, 1, H, W) from the real/imag channels stored last.
            target_image = target.permute(0, 3, 1, 2)
            mag_tar = torch.sqrt(target_image[:, 0:1] ** 2 + target_image[:, 1:2] ** 2)

            im_out = model(kspace_in, reference)
            loss = _ssim_vgg_loss(im_out, mag_tar)

            optimizer.zero_grad()
            # The graph is rebuilt every step, so it is not retained after backward.
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            avg_loss = 0.99 * avg_loss + 0.01 * loss_value if step > 0 else loss_value
            if step % 400 == 0:
                logging.info(
                    f'Epoch = [{epoch:3d}/{params.epoch:3d}] '
                    f'Iter = [{step:4d}/{len(train_loader):4d}] '
                    f'Loss = {loss_value:.4g} Avg Loss = {avg_loss:.4g}'
                )

        # Saving the model (also always on the final epoch).
        if epoch % CHECKPOINT_EVERY == 0 or epoch == params.epoch - 1:
            torch.save(
                {
                    'epoch': epoch,
                    # NOTE: `params` is a notebook-defined Namespace; unpickling
                    # this checkpoint requires that class to be importable.
                    'params': params,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'exp_dir': exp_dir,
                },
                f=os.path.join(exp_dir, 'model_%d.pt' % epoch),
            )

        running_loss = running_loss / len(train_loader)
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f'Epoch {epoch+1}, Learning rate: {current_lr}')

        # Per-epoch mean loss for the curve below.
        epochs_plot.append(epoch)
        losses_plot.append(running_loss)
finally:
    # Runs on completion and on interruption (e.g. KeyboardInterrupt), so the
    # history of all finished epochs is always written to exp_dir.
    fig, ax = plt.subplots()
    ax.plot(epochs_plot, losses_plot, label='Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('SA unrolled with Reference L2 train Loss')
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(exp_dir, 'loss_plot_plato_down.png'))

    # Save all_losses to a file for later comparison
    losses_file = os.path.join(exp_dir, 'all_losses.txt')
    with open(losses_file, 'w') as f:
        for epoch_loss in losses_plot:
            f.write(f'{epoch_loss}\n')

INFO:root:Epoch = [  0/ 31] Iter = [   0/ 134] Loss = 9.212 Avg Loss = 9.212


Epoch 1, Learning rate: 0.001


INFO:root:Epoch = [  1/ 31] Iter = [   0/ 134] Loss = 3.08 Avg Loss = 3.08


Epoch 2, Learning rate: 0.001


INFO:root:Epoch = [  2/ 31] Iter = [   0/ 134] Loss = 3.096 Avg Loss = 3.096


Epoch 3, Learning rate: 0.001


INFO:root:Epoch = [  3/ 31] Iter = [   0/ 134] Loss = 3.082 Avg Loss = 3.082


Epoch 4, Learning rate: 0.001


INFO:root:Epoch = [  4/ 31] Iter = [   0/ 134] Loss = 3.068 Avg Loss = 3.068


Epoch 5, Learning rate: 0.0003


INFO:root:Epoch = [  5/ 31] Iter = [   0/ 134] Loss = 2.882 Avg Loss = 2.882


Epoch 6, Learning rate: 0.0003


INFO:root:Epoch = [  6/ 31] Iter = [   0/ 134] Loss = 2.523 Avg Loss = 2.523


Epoch 7, Learning rate: 0.0003


INFO:root:Epoch = [  7/ 31] Iter = [   0/ 134] Loss = 2.476 Avg Loss = 2.476


Epoch 8, Learning rate: 0.0003


INFO:root:Epoch = [  8/ 31] Iter = [   0/ 134] Loss = 2.574 Avg Loss = 2.574


Epoch 9, Learning rate: 0.0003


INFO:root:Epoch = [  9/ 31] Iter = [   0/ 134] Loss = 2.43 Avg Loss = 2.43


Epoch 10, Learning rate: 8.999999999999999e-05


INFO:root:Epoch = [ 10/ 31] Iter = [   0/ 134] Loss = 2.196 Avg Loss = 2.196


Epoch 11, Learning rate: 8.999999999999999e-05


INFO:root:Epoch = [ 11/ 31] Iter = [   0/ 134] Loss = 2.197 Avg Loss = 2.197


Epoch 12, Learning rate: 8.999999999999999e-05


INFO:root:Epoch = [ 12/ 31] Iter = [   0/ 134] Loss = 2.283 Avg Loss = 2.283


Epoch 13, Learning rate: 8.999999999999999e-05


INFO:root:Epoch = [ 13/ 31] Iter = [   0/ 134] Loss = 2.29 Avg Loss = 2.29


Epoch 14, Learning rate: 8.999999999999999e-05


INFO:root:Epoch = [ 14/ 31] Iter = [   0/ 134] Loss = 2.254 Avg Loss = 2.254


Epoch 15, Learning rate: 2.6999999999999996e-05


INFO:root:Epoch = [ 15/ 31] Iter = [   0/ 134] Loss = 2.334 Avg Loss = 2.334


Epoch 16, Learning rate: 2.6999999999999996e-05


INFO:root:Epoch = [ 16/ 31] Iter = [   0/ 134] Loss = 2.228 Avg Loss = 2.228


KeyboardInterrupt: 