In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
import nibabel as nib
import numpy as np
from tqdm import tqdm
from model import *

def synthesize_maps(unet, t1w_image, t2w_image, device):
    """Run the U-Net on a T1w/T2w pair and return its two output channels.

    The two input volumes are stacked as channels, a batch axis of size one is
    prepended, and the network is evaluated without gradient tracking.

    Args:
        unet: Network to evaluate. It is switched to eval mode and moved to
            ``device`` in place (both are side effects on the caller's model).
        t1w_image: First input volume (array-like); becomes input channel 0.
        t2w_image: Second input volume (array-like, same shape as
            ``t1w_image``); becomes input channel 1.
        device: Torch device (or device string) to run inference on.

    Returns:
        ``(fa_map, adc_map)``: NumPy arrays taken from output channels 0 and 1
        respectively. The unpacking requires the model to emit exactly two
        channels for a batch of one; any other output layout raises.
    """
    unet.eval()
    unet.to(device)

    # (2, *spatial) -> (1, 2, *spatial): channels first, then a batch of one.
    channels = np.stack((t1w_image, t2w_image), axis=0)
    batch = torch.tensor(channels, dtype=torch.float32)[None].to(device)

    with torch.no_grad():
        prediction = unet(batch)

    # Drop the batch axis, bring the result to host memory, and split the
    # leading (channel) axis into the two maps.
    first_channel, second_channel = prediction.squeeze(0).cpu().numpy()
    return first_channel, second_channel

if __name__ == "__main__":
    # Select the inference device explicitly; `device` is not defined anywhere
    # in this file, so relying on it would raise NameError when run as a script.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained U-Net model.
    # NOTE(review): `data_dir` is not defined in this file; presumably it is
    # provided by `from model import *` or an earlier notebook cell — confirm.
    unet = initialize_unet(in_channels=2, out_channels=2)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    state_dict = torch.load(
        os.path.join(data_dir, "unet_model.pth"), map_location=device
    )
    unet.load_state_dict(state_dict)
    print("Model loaded.")

    # Keep the NIfTI image objects (not just the voxel arrays) so the input
    # geometry can be reused when writing the outputs.
    t1w_nii = nib.load("path/to/new/T1w_image.nii.gz")
    t2w_nii = nib.load("path/to/new/T2w_image.nii.gz")
    new_t1w_image = t1w_nii.get_fdata()
    new_t2w_image = t2w_nii.get_fdata()

    # The model stacks the two volumes voxel-wise, so they must share a grid.
    if new_t1w_image.shape != new_t2w_image.shape:
        raise ValueError(
            "T1w and T2w volumes must have the same shape, got "
            f"{new_t1w_image.shape} and {new_t2w_image.shape}"
        )

    # Synthesize FA and ADC maps
    synthesized_fa_map, synthesized_adc_map = synthesize_maps(
        unet, new_t1w_image, new_t2w_image, device
    )

    # Save the maps in the input's world space. An identity affine
    # (np.eye(4)) would discard the voxel-to-world mapping, so the outputs
    # would no longer overlay the T1w/T2w images they were derived from.
    out_affine = t1w_nii.affine
    nib.save(nib.Nifti1Image(synthesized_fa_map, out_affine), "path/to/output/FA_map.nii.gz")
    nib.save(nib.Nifti1Image(synthesized_adc_map, out_affine), "path/to/output/ADC_map.nii.gz")

    print("Synthesized FA and ADC maps saved.")
