In [1]:
import sys
sys.path.append('/share/gpu0/jjwhit/rcGAN/')

In [2]:
import torch
import yaml
import types
import json

import numpy as np
import matplotlib.patches as patches

from data.lightning.MassMappingDataModule import MMDataModule
from utils.parse_args import create_arg_parser
from pytorch_lightning import seed_everything
from models.lightning.mmGAN import mmGAN
from utils.mri.math import tensor_to_complex_np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from scipy import ndimage



In [4]:
def load_object(dct):
    return types.SimpleNamespace(**dct)

In [3]:
# Import model here
test_plot_model = mmGAN.load_from_checkpoint('/share/gpu0/jjwhit/mass_map/mm_models/mmgan_debug/checkpoint_best.ckpt')

In [5]:
with open('/share/gpu0/jjwhit/rcGAN/configs/mass_map.yml', 'r') as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg = json.loads(json.dumps(cfg), object_hook=load_object)

dm = MMDataModule(cfg)
fig_count = 1
dm.setup()
test_loader = dm.test_dataloader()

In [6]:
test_plot_model.cuda()
test_plot_model.eval()

mmGAN(
  (generator): UNetModel(
    (down_sample_layers): ModuleList(
      (0): ConvDownBlock(
        (conv_1): Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (res): ResidualBlock(
          (conv_block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): PReLU(num_parameters=1)
            (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (5): PReLU(num_parameters=1)
          )
          (conv_1x1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        )
        (conv_3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): PReLU(

In [9]:
for i, data in enumerate(test_loader):
    y, x, mean, std = data
    y = y.cuda()
    x = x.cuda()
    mean = mean.cuda()
    std = std.cuda()

    gens_rcgan = torch.zeros(
        size=(y.size(0), cfg.num_z_test, cfg.in_chans // 2, cfg.im_size, cfg.im_size, 2)).cuda()
    
    for z in range(cfg.num_z_test):
        gens_rcgan[:, z, :, :, :, :] = test_plot_model.reformat(test_plot_model.forward(y))

    avg_rcgan = torch.mean(gens_rcgan, dim=1)

    gt = test_plot_model.reformat(x)
    zfr = test_plot_model.reformat(y)

    for j in range(y.size(0)):
        np_avgs = {
            'rcgan': None,
        }

        np_samps = {
            'rcgan': [],
        }

        np_stds = {
            'rcgan': None,
        }

        np_gt = None


        np_gt = ndimage.rotate(
            torch.tensor(tensor_to_complex_np((gt[j] * std[j] + mean[j]).cpu())).abs().numpy(), 180) #What is the 180 for?
        np_zfr = ndimage.rotate(
            torch.tensor(tensor_to_complex_np((zfr[j] * std[j] + mean[j]).cpu())).abs().numpy(), 180)

        np_avgs['rcgan'] = ndimage.rotate(
            torch.tensor(tensor_to_complex_np((avg_rcgan[j] * std[j] + mean[j]).cpu())).abs().numpy(),
            180)

        for z in range(cfg.num_z_test):
            np_samps['rcgan'].append(ndimage.rotate(torch.tensor(
                tensor_to_complex_np((gens_rcgan[j, z] * std[j] + mean[j]).cpu())).abs().numpy(), 180))

        np_stds['rcgan'] = np.std(np.stack(np_samps['rcgan']), axis=0)

        method = 'rcgan'
        zoom_startx = np.random.randint(120, 250) #What is this section for?
        zoom_starty1 = np.random.randint(30, 80)
        zoom_starty2 = np.random.randint(260, 300)

        p = np.random.rand()
        zoom_starty = zoom_starty1
        if p <= 0.5:
            zoom_starty = zoom_starty2

        zoom_length = 80

        x_coord = zoom_startx + zoom_length
        y_coords = [zoom_starty, zoom_starty + zoom_length]

        # Global recon, error, std
        nrow = 1
        ncol = 4

        fig = plt.figure(figsize=(ncol + 1, nrow + 1))

        gs = gridspec.GridSpec(nrow, ncol,
                               wspace=0.0, hspace=0.0,
                               top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                               left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1))

        ax = plt.subplot(gs[0, 0])
        ax.imshow(np_gt, cmap='gray', vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title("Truth")

        ax = plt.subplot(gs[0, 1])
        ax.imshow(np_avgs[method], cmap='gray', vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(method)


        ax = plt.subplot(gs[0, 2])
        im = ax.imshow(2 * np.abs(np_avgs[method] - np_gt), cmap='jet', vmin=0,
                       vmax=np.max(np.abs(np_avgs['rcgan'] - np_gt)))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title("Error")


        ax = plt.subplot(gs[0, 3])
        ax.imshow(np_stds[method], cmap='viridis', vmin=0, vmax=np.max(np_stds['rcgan']))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title("Std. Dev.")

        plt.savefig(f'/share/gpu0/jjwhit/test_figures/test_fig_avg_err_std_{fig_count}.png', bbox_inches='tight', dpi=300)
        plt.close(fig)

        nrow = 1
        ncol = 8

        fig = plt.figure(figsize=(ncol + 1, nrow + 1))

        gs = gridspec.GridSpec(nrow, ncol,
                               wspace=0.0, hspace=0.0,
                               top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                               left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1))

        ax = plt.subplot(gs[0, 0])
        ax.imshow(np_gt, cmap='gray', vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('Truth')

        ax1 = ax

        rect = patches.Rectangle((zoom_startx, zoom_starty), zoom_length, zoom_length, linewidth=1,
                                 edgecolor='r',
                                 facecolor='none')

        # Add the patch to the Axes
        ax.add_patch(rect)

        ax = plt.subplot(gs[0, 1])
        ax.imshow(np_gt[zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length],
                  cmap='gray',
                  vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('Truth')

        connection_path_1 = patches.ConnectionPatch([zoom_startx + zoom_length, zoom_starty],
                                                    [0, 0], coordsA=ax1.transData,
                                                    coordsB=ax.transData, color='r')
        fig.add_artist(connection_path_1)
        connection_path_2 = patches.ConnectionPatch([zoom_startx + zoom_length, zoom_starty + zoom_length], [0, zoom_length],
                                                    coordsA=ax1.transData,
                                                    coordsB=ax.transData, color='r')
        fig.add_artist(connection_path_2)

        ax = plt.subplot(gs[0, 2])
        ax.imshow(
            np_avgs[method][zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length],
            cmap='gray', vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('32-Avg.')

        ax = plt.subplot(gs[0, 3])
        avg = np.zeros((384, 384))
        for l in range(4):
            avg += np_samps[method][l]

        avg = avg / 4
        
        ax.imshow(
            avg[zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length],
            cmap='gray', vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('4-Avg.')

        ax = plt.subplot(gs[0, 4])
        avg = np.zeros((384, 384))
        for l in range(2):
            avg += np_samps[method][l]

        avg = avg / 2
        ax.imshow(
            avg[zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length],
            cmap='gray', vmin=0, vmax=0.7 * np.max(np_gt))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('2-Avg.')

        for samp in range(2):
            ax = plt.subplot(gs[0, samp + 5])
            ax.imshow(np_samps[method][samp][zoom_starty:zoom_starty + zoom_length,
                      zoom_startx:zoom_startx + zoom_length], cmap='gray', vmin=0,
                      vmax=0.7 * np.max(np_gt))
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f'Sample {samp + 1}')


        ax = plt.subplot(gs[0, 7])
        ax.imshow(np_stds[method][zoom_starty:zoom_starty + zoom_length,
                  zoom_startx:zoom_startx + zoom_length], cmap='viridis', vmin=0,
                  vmax=np.max(np_stds['rcgan']))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('Std. Dev.')

        plt.savefig(f'/share/gpu0/jjwhit/test_figures/zoomed_avg_samps_{fig_count}.png', bbox_inches='tight', dpi=300)
        plt.close(fig)

        if fig_count == args.num_figs:
            exit()
        fig_count += 1



OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 79.10 GiB total capacity; 73.68 GiB already allocated; 1.82 GiB free; 75.82 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF