In [1]:
import numpy as np 
import torch 
from torch.utils.data import Dataset
from simulator import FSE_signal_TR

T = 32           # number of angles in theta_hat (built as a (1, T) array); presumably echo-train length
TE = 9           # echo time forwarded to FSE_signal_TR (units presumably ms — confirm)
TR = 2800        # repetition time forwarded to FSE_signal_TR (units presumably ms — confirm)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
batch_size = 50
num_epochs = 150                  # number of unrolled gradient steps inside pbnet
theta_hat_init_angle = 125.       # initial angle in degrees (converted to radians before use)
that_hat_init_angle = theta_hat_init_angle  # legacy misspelled name, kept for existing cells
step_size_init_val = 1.           # initial gradient-descent step size for pbnet

In [2]:
class Decoder_MRIDataset(Dataset):
    """Per-sample (T1, T2) dataset built from a stored image array.

    T2 targets are the real part of slice ``[:, :, 0, :]`` of the array stored
    at ``path``, flattened to 1-D and cast to float32. T1 is not read from disk:
    every sample gets the same constant ``t1_value``.

    Args:
        path: ``.npy`` file to load. The array must support ``[:, :, 0, :]``
            indexing (at least 4-D); the meaning of each axis is not visible
            here — TODO confirm.
        t1_value: constant T1 assigned to every sample (units presumably ms —
            confirm).

    Each item is a ``(t1, t2)`` pair of 0-d float32 tensors.
    """

    def __init__(self, path='data/image.npy', t1_value=1000.):
        brain = np.load(path)[:, :, 0, :].real.flatten()

        t2 = torch.tensor(brain, dtype=torch.float32)
        # T1 is held fixed for every sample; only T2 varies.
        t1 = torch.full_like(t2, float(t1_value))

        self.len = t2.shape[0]
        self.y_data = t2   # kept for callers reading y_data; same tensor as t2_data
        self.t1_data = t1
        self.t2_data = t2

    def __getitem__(self, index):
        return self.t1_data[index], self.t2_data[index]

    def __len__(self):
        return self.len


In [3]:
def pbnet(theta_hat, step_size, TE, TR, T1, T2, testFlag=True,
          num_iters=None, init_t2=100.):
    """Estimate T2 by unrolled gradient descent on a signal-matching loss.

    A reference signal ``y_meas`` is simulated from the given (T1, T2) with
    ``FSE_signal_TR``. Starting from the constant guess ``init_t2`` for every
    sample, the T2 estimate is updated ``num_iters`` times with a plain
    gradient step on ``sum((y_est - y_meas) ** 2)``.

    Args:
        theta_hat: angle parameters forwarded to ``FSE_signal_TR``; its
            ``.device`` is where the estimate is allocated.
        step_size: gradient-descent step size (scalar or broadcastable tensor).
        TE, TR: forwarded unchanged to ``FSE_signal_TR``.
        T1, T2: reference relaxation values for the batch. The estimate takes
            its shape from ``T2``, so any batch size works, including a smaller
            final DataLoader batch.
        testFlag: if True, ``y_meas`` is detached and the per-step gradients
            are computed with ``create_graph=False``. If False, the gradients
            are built with ``create_graph=True`` so the unrolled loop stays
            differentiable (w.r.t. ``theta_hat`` assuming ``FSE_signal_TR`` is
            differentiable in it).
        num_iters: number of gradient steps; None uses the notebook-level
            ``num_epochs``.
        init_t2: constant initial T2 guess.

    Returns:
        Tensor shaped like ``T2`` holding the final T2 estimate.
    """
    if num_iters is None:
        num_iters = num_epochs

    # Size the initial guess from T2 (not the global batch_size) so a partial
    # last batch does not produce a shape mismatch against y_meas.
    myt2 = torch.full(tuple(T2.shape), float(init_t2), dtype=torch.float32,
                      device=theta_hat.device, requires_grad=True)

    y_meas = FSE_signal_TR(theta_hat, TE, TR, T1, T2, B1=1.)
    if testFlag:
        y_meas = y_meas.detach()

    for _ in range(num_iters):
        y_est = FSE_signal_TR(theta_hat, TE, TR, T1, myt2, B1=1.)
        loss_dc = torch.sum((y_est - y_meas) ** 2)
        # create_graph=True (training) keeps the gradient itself differentiable,
        # so backprop through the unrolled steps can reach step_size/theta_hat.
        (g,) = torch.autograd.grad(loss_dc, myt2, create_graph=not testFlag)
        myt2 = myt2 - step_size * g   # gradient update
    return myt2

In [4]:
# Build the dataset and batch it in a fixed (unshuffled) order.
dataset = Decoder_MRIDataset()
# __getitem__ only indexes in-memory tensors, so worker processes add no
# throughput; num_workers=0 also avoids pickling this notebook-defined class,
# which fails in workers started with the 'spawn' method (macOS/Windows).
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    num_workers=0,
    drop_last=False,   # keep the final, possibly smaller, batch
    shuffle=False,
)

In [None]:
# Initial angles: one constant value (degrees) repeated T times as a (1, T) array.
final_theta = that_hat_init_angle * np.ones((1, T))
# Degrees -> radians, created directly on the target device.
theta_hat_init = torch.tensor(final_theta / 180 * np.pi,
                              dtype=torch.float32, device=device)
theta_hat = theta_hat_init.clone().detach().requires_grad_(True)

myt2_init = 100. * torch.ones(batch_size, dtype=torch.float32,
                              device=device, requires_grad=True)

step_size_init = torch.tensor([step_size_init_val],
                              dtype=torch.float32, device=device)
step_size = step_size_init.clone().detach().requires_grad_(True)

In [None]:
theta_hat_init.size()  # sanity check: expected torch.Size([1, T])

In [None]:
# Smoke test on the first batch only: run the unrolled estimator in test mode
# and report the MSE between the estimated and reference T2.
# (Raises StopIteration if the loader is empty instead of silently doing nothing.)
T1, T2 = next(iter(data_loader))
t1, t2 = T1.to(device), T2.to(device)
myt2 = pbnet(theta_hat, step_size, TE, TR, t1, t2, testFlag=True)
loss = torch.mean((myt2 - t2) ** 2)
print(t2[:5])
print(loss)

In [None]:
# Raw T2 values as loaded by Decoder_MRIDataset, before its float32 cast.
brain = np.real(np.load('data/image.npy')[:, :, 0, :]).flatten()

In [None]:
brain[0:5]  # first five flattened values