In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from model import *

In [2]:
import pathlib
import fastmri
from fastmri.data import subsample
from fastmri.data import transforms as T, mri_data

# Create a mask function
mask_func = subsample.RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8])

def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    # Transform the data into appropriate format
    
    crop_kspace = T.complex_center_crop(T.to_tensor(kspace), (320,320))

    masked_kspace, mask = T.apply_mask(T.to_tensor(kspace), mask_func)   # Apply the mask to k-space
    _, mask2 = T.apply_mask(crop_kspace, mask_func)
    
    ifft_masked_kspace = fastmri.ifft2c(masked_kspace)           # Apply Inverse Fourier Transform to get the complex image
    crop_ifft_mk = T.complex_center_crop(ifft_masked_kspace, (320,320))
    #abs_masked_kspace = fastmri.complex_abs(crop_ifft_mk)   # Compute absolute value to get a real image
    
    original_img = fastmri.ifft2c(T.to_tensor(kspace))   #Inverse fourier transf. to the original image
    crop_original_img = T.complex_center_crop(original_img, (320,320))
    #abs_original_img = fastmri.complex_abs(crop_original_img)
    
    crop_ifft_mk = crop_ifft_mk.unsqueeze(0).unsqueeze(0)
    crop_original_img = crop_original_img.unsqueeze(0).unsqueeze(0)
    
    return crop_ifft_mk, crop_original_img, mask2

dataset = mri_data.SliceDataset(
    root=pathlib.Path('./mimmo'),
    transform=data_transform,
    challenge='singlecoil'
)

In [3]:
import torch.optim as optim

net = Net()

criterion = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.95, amsgrad=False)

In [14]:
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    i=0
    for crop_ifft_mk, crop_original_img, mask2 in dataset:
        
        optimizer.zero_grad()
        
        outputs = net(crop_ifft_mk, crop_original_img, mask2)
        
        #figuraoutput = outputs.squeeze(0).squeeze(0)
        #sampled_image = fastmri.ifft2c(figuraoutput)  
        #sampled_image_abs = fastmri.complex_abs(sampled_image)
        
        #fig = plt.figure()
        #plt.imshow(sampled_image_abs.detach().numpy(), cmap='gray')
        #plt.savefig('pippo'+str(i)+'.jpg')
        
        loss = criterion(crop_original_img, outputs)
        
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        print('Epoch n° (%d) loss: %.3f' % (epoch + 1, running_loss))
        
        i += 1
        #running_loss = 0.0
        #print(crop_ifft_mk.shape)
        #print(crop_original_img.shape)
        pass
    
print('Finished Training')

Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss: 0.501
Cascade n° ( 1 )
Epoch n° (1) loss

KeyboardInterrupt: 