In [11]:
import os

import matplotlib.pyplot as plt
import scipy.io as sio
import torch
from scipy.io import savemat

from spenpy.spen import spen

class physical_model:
    """Forward / inverse SPEN encoding model.

    Holds the pair of matrices returned by
    ``spen(acq_point=img_size).get_InvA()``: ``AFinal`` (scaled by ``1j``) is
    used for the forward direction and ``InvA`` for reconstruction. An optional
    phase map is applied to (or removed from) the odd-indexed entries along
    dim 1 of the data.
    """

    def __init__(self, img_size=(96, 96)):
        # NOTE(review): assumes get_InvA() returns torch tensors usable with
        # torch.matmul against the complex inputs below — confirm in spenpy.
        self.InvA, self.AFinal = spen(acq_point=img_size).get_InvA()

    def __call__(self, x, phase_map=None):
        """Simulate acquisition: ``(1j * AFinal) @ x``, optionally adding phase.

        Args:
            x: tensor with at least 3 dims (``[:, 1::2, :]`` indexing);
                presumably (batch, rows, cols) — TODO confirm.
            phase_map: optional tensor broadcastable to ``x[:, 1::2, :]``;
                odd rows of the result are multiplied by ``exp(1j * phase_map)``.

        Returns:
            A new tensor; ``x`` itself is not modified.
        """
        y = torch.matmul(self.AFinal * 1j, x)
        if phase_map is not None:
            # y is a fresh tensor produced by matmul, so updating it in place
            # cannot affect the caller's data.
            y[:, 1::2, :] *= torch.exp(1j * phase_map)
        return y

    def recons(self, x, phase_map=None):
        """Reconstruct from measurements: optionally remove phase, then ``InvA @ x``.

        Args:
            x: measured tensor, same layout as the output of ``__call__``.
            phase_map: optional tensor broadcastable to ``x[:, 1::2, :]``;
                odd rows are multiplied by ``exp(-1j * phase_map)`` first.

        Returns:
            The reconstruction as a new tensor; ``x`` is left unchanged.
        """
        if phase_map is not None:
            # Clone first: the phase correction is applied in place and must
            # not leak back into the caller's tensor.
            x = x.clone()
            x[:, 1::2, :] *= torch.exp(-1j * phase_map)
        return torch.matmul(self.InvA, x)

def _load_last_mat_var(mat_path):
    """Load a ``.mat`` file and return its last stored variable.

    ``scipy.io.loadmat`` also returns metadata entries whose names start with
    ``__`` (e.g. ``__header__``); those are skipped so the result is always a
    real stored variable.
    """
    mat = sio.loadmat(mat_path)
    var_names = [k for k in mat if not k.startswith("__")]
    if not var_names:
        raise ValueError(f"no variables found in {mat_path}")
    return mat[var_names[-1]]


def get_demo(data_root="/home/data1/musong/workspace/2025/8/08-20/tr/test_data",
             id="IXI050-Guys-0711-T1_idx0025.mat",
             if_phase_map=False):
    """Load one HR sample, simulate its SPEN acquisition, and return the arrays.

    Args:
        data_root: directory containing ``hr/`` and, if used, ``phase_map/``.
        id: file name of the ``.mat`` sample inside those subdirectories.
        if_phase_map: if True, load ``phase_map/{id}`` and apply it in the
            forward model.

    Returns:
        ``(hr, lr, phase_map)`` as squeezed numpy arrays. ``phase_map`` is
        ``None`` when ``if_phase_map`` is False (previously this path raised
        ``NameError``).
    """
    PM = physical_model()

    phase_map = None
    if if_phase_map:
        phase_map = _load_last_mat_var(f"{data_root}/phase_map/{id}")
        phase_map = torch.tensor(phase_map, dtype=torch.complex64).unsqueeze(0)

    data = _load_last_mat_var(f"{data_root}/hr/{id}")
    # Scale so the maximum value becomes 1.
    # NOTE(review): assumes a real, non-zero maximum — confirm for all inputs.
    data = data / data.max()
    data = torch.tensor(data, dtype=torch.complex64).unsqueeze(0)

    # phase_map=None means "no phase error" in the forward model.
    lr = PM(data, phase_map=phase_map)

    hr_np = data.squeeze().numpy()
    lr_np = lr.squeeze().numpy()
    pm_np = None if phase_map is None else phase_map.squeeze().numpy()
    return hr_np, lr_np, pm_np


In [12]:
# Generate a phase-corrupted demo sample and save its HR / LR / phase-map
# arrays under the same test_data tree that get_demo reads from.
# NOTE: `id` shadows the builtin; kept because later cells may reference it.
id="IXI050-Guys-0711-T1_idx0025_from_pm"
path = "/home/data1/musong/workspace/2025/8/08-20/tr/test_data"
# Pass the root explicitly so input and output locations cannot drift apart.
data, lr, phase_map = get_demo(data_root=path, if_phase_map=True)
# savemat does not create directories; make sure every output subdir exists.
for subdir in ("hr", "lr", "phase_map"):
    os.makedirs(f"{path}/{subdir}", exist_ok=True)
savemat(f"{path}/hr/{id}.mat", {"hr": data})
savemat(f"{path}/lr/{id}.mat", {"lr": lr})
savemat(f"{path}/phase_map/{id}.mat", {"phase_map": phase_map})