In [1]:
import numpy as np 
import torch 
from torch.utils.data import Dataset
from simulator import FSE_signal_TR
from tqdm import tqdm 

# --- Sequence parameters (passed through to FSE_signal_TR) -----------------
T = 32          # number of flip angles in the train; matches the 32 signal samples per voxel
TE = 9          # echo time; units presumably ms — TODO confirm against simulator
TR = 2800       # repetition time; units presumably ms — TODO confirm against simulator

# --- Runtime configuration ---------------------------------------------------
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
batch_size = 5000   # voxels (dataset rows) per DataLoader batch
num_epochs = 150    # gradient-descent iterations per batch inside pbnet (not training epochs)

# --- Initial values for the learnable parameters ----------------------------
theta_hat_init_angle = 125.                 # initial flip angle in degrees (converted to radians later)
that_hat_init_angle = theta_hat_init_angle  # legacy (misspelled) name kept for existing cells
step_size_init_val = 1.                     # initial step size of the T2 gradient update

In [2]:
class Decoder_MRIDataset(Dataset):
    """Per-voxel MRI signal dataset used for T2 fitting.

    Loads a stored image array, takes its magnitude, and keeps index 0 along
    axis 2. It then flattens the two leading (spatial) axes so that each row
    is one voxel's signal vector. Each item is ``(t1, signal)``, where ``t1``
    is the same constant T1 value for every voxel.

    Args:
        path: ``.npy`` file to load. Expected to be 4-D with the signal
            samples on the last axis; only ``[:, :, 0, :]`` is used
            (NOTE(review): meaning of axis 2 not visible here — confirm).
        t1: constant T1 value assigned to every voxel (default 1000.).
    """

    def __init__(self, path='data/image.npy', t1=1000.):
        # Magnitude of the (possibly complex) image; keep slice 0 of axis 2.
        image = np.abs(np.load(path))[:, :, 0, :]
        # Flatten spatial axes -> (n_voxels, n_samples). Derived from the data
        # instead of the previous hard-coded (288*288, 32).
        # Cast to float32 to match the float32 T1 tensor and model dtype.
        signals = image.reshape(-1, image.shape[-1]).astype(np.float32)

        self.len = signals.shape[0]
        self.y_data = signals
        self.t1_data = torch.full((self.len,), float(t1), dtype=torch.float32)

    def __getitem__(self, index):
        return self.t1_data[index], self.y_data[index, :]

    def __len__(self):
        return self.len


In [3]:
def pbnet(y_meas, theta_hat, step_size, TE, TR, T1, testFlag=True, num_iters=None):
    """Estimate per-voxel T2 by unrolled gradient descent on a signal-fit loss.

    Starts from a value of 100 for every row of ``y_meas``. On each step it
    simulates the signal with ``FSE_signal_TR`` and takes a gradient step on
    the summed squared residual against ``y_meas``. The fitted quantity
    (``myt2``) is passed to the simulator in its T2 position (presumably T2;
    confirm against the simulator).

    Args:
        y_meas: measured signals, one row per voxel (assumed
            ``(batch, n_samples)`` — TODO confirm against simulator output).
        theta_hat: flip angles in radians, forwarded to ``FSE_signal_TR``.
        step_size: gradient-descent step size (tensor or float).
        TE, TR: sequence timing, forwarded to ``FSE_signal_TR``.
        T1: per-voxel T1 values, forwarded to ``FSE_signal_TR``.
        testFlag: if True (inference), ``y_meas`` is detached, no
            higher-order graph is built, and the returned estimate is
            detached. If False, the unrolled steps stay differentiable w.r.t.
            ``theta_hat`` and ``step_size`` (``create_graph=True``).
        num_iters: number of gradient steps; defaults to the notebook-level
            ``num_epochs``.

    Returns:
        1-D tensor with one estimate per row of ``y_meas``.
    """
    if num_iters is None:
        num_iters = num_epochs

    myt2 = torch.full((y_meas.shape[0],), 100., dtype=torch.float32,
                      device=theta_hat.device, requires_grad=True)
    if testFlag:
        y_meas = y_meas.detach()

    for _ in range(num_iters):
        y_est = FSE_signal_TR(theta_hat, TE, TR, T1, myt2, B1=1.).squeeze()
        res = y_est - y_meas
        loss_dc = torch.sum(res**2)
        g = torch.autograd.grad(loss_dc, myt2, create_graph=not testFlag)[0]

        myt2 = myt2 - step_size * g  # gradient update
        if testFlag:
            # Nothing backprops through the unrolled steps in inference mode.
            # Because step_size requires grad, however, the update above
            # would chain every iteration into one growing graph. Cut it, and
            # re-enable grad so the next autograd.grad call still works.
            myt2 = myt2.detach().requires_grad_(True)
    return myt2

In [4]:
# Deterministic, in-order batching: batch k covers dataset rows
# [k*batch_size, (k+1)*batch_size), and a final partial batch is kept.
dataset = Decoder_MRIDataset()
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

In [5]:
# Initial flip-angle train: all T angles start at `that_hat_init_angle`
# degrees, converted to radians and placed on `device`.
final_theta = np.full((1, T), that_hat_init_angle, dtype=float)
theta_hat_init = torch.tensor(final_theta / 180 * np.pi, dtype=torch.float32).to(device)
# Trainable copy, decoupled from the stored initial value.
theta_hat = theta_hat_init.detach().clone().requires_grad_(True)

# Starting value (100) per voxel, sized for one full batch.
# NOTE(review): pbnet builds its own starting value, so pbnet does not
# consume this template; it may also not match a smaller final batch.
myt2_init = 100. * torch.ones(batch_size, dtype=torch.float32, requires_grad=True, device=device)

# Trainable gradient-descent step size as a 1-element tensor on `device`.
step_size_init = torch.tensor([step_size_init_val], dtype=torch.float32).to(device)
step_size = step_size_init.detach().clone().requires_grad_(True)

In [6]:
# Sanity check: one flip angle per position in the train, stored as a (1, T) row.
assert theta_hat_init.shape == (1, T), theta_hat_init.shape
theta_hat_init.shape

torch.Size([1, 32])

In [7]:
# Fit T2 for every batch of voxels. pbnet needs autograd internally, so this
# loop cannot run under torch.no_grad(). Instead, each batch's estimate is
# detached and moved to the CPU, which releases GPU memory and graphs.
t2_batches = []
for T1, y in tqdm(data_loader, desc='pbnet T2 fit'):
    t1 = T1.to(device)
    y_meas = y.to(device)
    myt2 = pbnet(y_meas, theta_hat, step_size, TE, TR, t1, testFlag=True)
    t2_batches.append(myt2.detach().cpu())

# Keep every batch's estimate, not just the last one that `myt2` holds.
t2_est = torch.cat(t2_batches)  # one value per dataset row, in dataset order (shuffle=False)
assert t2_est.shape[0] == len(dataset)

17it [05:11, 18.31s/it]
