In [1]:
%%html
<style>
.container{width:85%}
<style>

In [2]:
import RL_samp
from RL_samp.header import *
from RL_samp.utils import *
from RL_samp.replay_buffer import *
from RL_samp.models import poly_net, val_net
from RL_samp.reconstructors import unet_solver
from RL_samp.policies import DQN
from RL_samp.trainers import DeepQL_trainer #, AC1_ET_trainer

from importlib import reload
import matplotlib.pyplot as plt
from matplotlib import colors
import torch.nn.functional as Func

def rolling_mean(x,window):
    window = int(window)

    # Stephen: for large data, the above gets a bit slow, so we can do this:
#   y = np.convolve(x, np.ones(window)/window, mode='valid')
#   return y
    # or https://stackoverflow.com/a/27681394
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[window:] - cumsum[:-window]) / float(window)

# Run Tester : DDQN

In [3]:
hist_dir = '/home/ec2-user/SageMaker/RLsamp/output/DQN_doubleQ_True_ba8_bu16_hist_2023-04-10_BA8BU16_LF_0G_600Epochs.pt'
hist = torch.load(hist_dir)

In [4]:
from RL_samp.trainers import DeepQL_tester
from unet.unet_model_fbr import Unet

In [5]:
# dataloader params
t_backtrack = 8
datapath = '/home/ec2-user/SageMaker/data/OCMR_fully_sampled_images/'

# unet params
unet_rand_dir    = '/home/ec2-user/SageMaker/RLsamp/output/recon_models/unet_lowfreq_rand_1.0_fbr_2_chans_64base8_budget16.pt'
unet_lowfreq_dir = '/home/ec2-user/SageMaker/RLsamp/output/recon_models/unet_lowfreq_rand_0.0_fbr_2_chans_64base8_budget16.pt'
in_chans = 2
chans = 64
num_pool_layers = 6

# policy params
discount = .5

# tester params
fulldim = 144
base    = 8
budget  = 16
device  = torch.device("cuda:0")
save_dir = '/home/ec2-user/SageMaker/RLsamp/output/'

infostr = 'DDQN_test'

In [6]:
ncfiles = np.load('/home/ec2-user/SageMaker/RLsamp/test_files.npz')['files']
loader  = ocmrLoader(ncfiles,batch_size=1,datapath=datapath,t_backtrack=t_backtrack,train_mode=False)

model   = poly_net(samp_dim=fulldim,in_chans=t_backtrack)
model.load_state_dict(hist['dqn_weights'])

unet_rand = Unet(in_chans=in_chans,out_chans=1,chans=chans,
            num_pool_layers=num_pool_layers,drop_prob=0).to(device)
rand_checkpoint = torch.load(unet_rand_dir)
unet_rand.load_state_dict(rand_checkpoint['model_state_dict'])

unet_lowfreq = Unet(in_chans=in_chans,out_chans=1,chans=chans,
            num_pool_layers=num_pool_layers,drop_prob=0).to(device)
lowfreq_checkpoint = torch.load(unet_lowfreq_dir)
unet_lowfreq.load_state_dict(lowfreq_checkpoint['model_state_dict'])

policy = DQN(model,[],device=device,gamma=discount,lr=0,
                  double_q_mode=True,unet=unet_lowfreq,mag_weight=5,maxGuideEp=0)


tester = DeepQL_tester(loader,policy,
                         eps=1e-3,
                         fulldim=fulldim,
                         base=base,
                         budget=budget,
                         save_dir=save_dir,
                         compare=True,
                         rand_eval_unet=unet_rand,
                         lowfreq_eval_unet=unet_lowfreq,
                         infostr=None,
                         device=device
                         )

current file: fs_0021_3T.pt
Dimension of the current data file: t_ubd 22, slice_ubd 1, rep_ubd 1


In [7]:
tester.test()

file [1/10] rep [1/1] slice [1/1]
current file: fs_0058_1_5T.pt
Dimension of the current data file: t_ubd 22, slice_ubd 1, rep_ubd 1
file [2/10] rep [1/1] slice [1/1]
current file: fs_0038_3T.pt
Dimension of the current data file: t_ubd 16, slice_ubd 1, rep_ubd 1
file [3/10] rep [1/1] slice [1/1]
current file: fs_0041_3T.pt
Dimension of the current data file: t_ubd 18, slice_ubd 1, rep_ubd 1
file [4/10] rep [1/1] slice [1/1]
current file: fs_0042_3T.pt
Dimension of the current data file: t_ubd 17, slice_ubd 1, rep_ubd 1
file [5/10] rep [1/1] slice [1/1]
current file: fs_0035_3T.pt
Dimension of the current data file: t_ubd 18, slice_ubd 1, rep_ubd 1
file [6/10] rep [1/1] slice [1/1]
current file: fs_0023_3T.pt
Dimension of the current data file: t_ubd 22, slice_ubd 1, rep_ubd 1
file [7/10] rep [1/1] slice [1/1]
current file: fs_0025_3T.pt
Dimension of the current data file: t_ubd 27, slice_ubd 1, rep_ubd 1
file [8/10] rep [1/1] slice [1/1]
current file: fs_0074_1_5T.pt
Dimension of the 

In [8]:
# eval_hist_path = '/home/ec2-user/SageMaker/RLsamp/output/EVAL_DQN_doubleQ_True_ba8_bu16_2023-05-08.pt'
eval_hist_path = '/home/ec2-user/SageMaker/RLsamp/output/EVAL_DQN_doubleQ_True_ba8_bu16_2023-05-29.pt'
eval_hist = torch.load(eval_hist_path)

In [9]:
print('DDQN rmse test: ', np.mean(eval_hist['testing_record']['rmse']))
print('Rand. rmse test    : ', np.mean(eval_hist['testing_record']['rmse_rand']))
print('LowFreq. rmse test : ', np.mean(eval_hist['testing_record']['rmse_lowfreq']))
print('\n')
print('DDQN ssim test: ', np.mean(eval_hist['testing_record']['ssim']))
print('Rand. ssim test    : ', np.mean(eval_hist['testing_record']['ssim_rand']))
print('LowFreq. ssim test : ', np.mean(eval_hist['testing_record']['ssim_lowfreq']))

DDQN rmse test:  0.3948401342732642
Rand. rmse test    :  0.5666931271386474
LowFreq. rmse test :  0.3948401342732642


DDQN ssim test:  0.7522064061281856
Rand. ssim test    :  0.5785664886946684
LowFreq. ssim test :  0.7522064061281856


# Run Tester : REINFORCE

In [3]:
from RL_samp.REINFORCE import REINFORCE_tester
from unet.unet_model_fbr import Unet

In [31]:
# hist_dir = '/home/ec2-user/SageMaker/RLsamp/output/REINFORCE_hist_2023-04-24_base8_budget16_BA8_BU16_E1000_G0_H1e-2_wTrue_magweg5_rwd1.pt'
hist_dir = '/home/ec2-user/SageMaker/RLsamp/output/REINFORCE_hist_2023-04-25_base8_budget16_BA8_BU16_E1000_G300_H1e-2_wTrue_magweg5_rwd1.pt'
hist = torch.load(hist_dir)

In [32]:
# dataloader params
t_backtrack = 8
datapath = '/home/ec2-user/SageMaker/data/OCMR_fully_sampled_images/'

# unet params
unet_rand_dir = '/home/ec2-user/SageMaker/RLsamp/output/recon_models/unet_lowfreq_rand_1.0_fbr_2_chans_64base8_budget16.pt'
unet_lowfreq_dir = '/home/ec2-user/SageMaker/RLsamp/output/recon_models/unet_lowfreq_rand_0.0_fbr_2_chans_64base8_budget16.pt'
unet_prob_dir = '/home/ec2-user/SageMaker/RLsamp/output/recon_models/unet_prob_rand_0.0_fbr_2_chans_64base8_budget16.pt'

probdistr_dir = '/home/ec2-user/SageMaker/RLsamp/output/probdist_train_base_8.pt'

in_chans = 2
chans = 64
num_pool_layers = 6

# policy params
discount = .5

# tester params
fulldim = 144
base    = 8
budget  = 16
device  = torch.device("cuda:0")
save_dir = '/home/ec2-user/SageMaker/RLsamp/output/'

infostr = 'REINFORCE_test'

In [33]:
ncfiles = np.load('/home/ec2-user/SageMaker/RLsamp/test_files.npz')['files']
loader  = ocmrLoader(ncfiles,batch_size=1,datapath=datapath,t_backtrack=t_backtrack,train_mode=False)

model   = poly_net(samp_dim=fulldim,softmax=True,in_chans=t_backtrack)
model.load_state_dict(hist['polynet_weights'])

unet_rand = Unet(in_chans=in_chans,out_chans=1,chans=chans,
            num_pool_layers=num_pool_layers,drop_prob=0).to(device)
rand_checkpoint = torch.load(unet_rand_dir)
unet_rand.load_state_dict(rand_checkpoint['model_state_dict'])

unet_lowfreq = Unet(in_chans=in_chans,out_chans=1,chans=chans,
            num_pool_layers=num_pool_layers,drop_prob=0).to(device)
lowfreq_checkpoint = torch.load(unet_lowfreq_dir)
unet_lowfreq.load_state_dict(lowfreq_checkpoint['model_state_dict'])

unet_lowfreq = Unet(in_chans=in_chans,out_chans=1,chans=chans,
            num_pool_layers=num_pool_layers,drop_prob=0).to(device)
lowfreq_checkpoint = torch.load(unet_lowfreq_dir)
unet_lowfreq.load_state_dict(lowfreq_checkpoint['model_state_dict'])

unet_prob = Unet(in_chans=in_chans,out_chans=1,chans=chans,
            num_pool_layers=num_pool_layers,drop_prob=0).to(device)
prob_checkpoint = torch.load(unet_prob_dir)
unet_prob.load_state_dict(lowfreq_checkpoint['model_state_dict'])

probdistr = torch.load(probdistr_dir)[f'probability_density_base_{base}']

In [34]:
tester = REINFORCE_tester(loader,model,
                         fulldim=fulldim,
                         base=base,
                         budget=budget,
                         save_dir=save_dir,
                         unet=unet_lowfreq,
                         rand_eval_unet=unet_rand,
                         lowfreq_eval_unet=unet_lowfreq,
                         prob_eval_unet=unet_prob,
                         probdistr=probdistr,
                         infostr=infostr,
                         device=device
                         )

current file: fs_0060_1_5T.pt
Dimension of the current data file: t_ubd 21, slice_ubd 12, rep_ubd 1


In [35]:
tester.run()

file [1/10] rep [1/1] slice [1/12]
current file: fs_0041_3T.pt
Dimension of the current data file: t_ubd 18, slice_ubd 1, rep_ubd 1
file [2/10] rep [1/1] slice [1/1]
current file: fs_0042_3T.pt
Dimension of the current data file: t_ubd 17, slice_ubd 1, rep_ubd 1
file [3/10] rep [1/1] slice [1/1]
current file: fs_0058_1_5T.pt
Dimension of the current data file: t_ubd 22, slice_ubd 1, rep_ubd 1
file [4/10] rep [1/1] slice [1/1]
current file: fs_0038_3T.pt
Dimension of the current data file: t_ubd 16, slice_ubd 1, rep_ubd 1
file [5/10] rep [1/1] slice [1/1]
current file: fs_0074_1_5T.pt
Dimension of the current data file: t_ubd 19, slice_ubd 12, rep_ubd 1
file [6/10] rep [1/1] slice [1/12]
current file: fs_0035_3T.pt
Dimension of the current data file: t_ubd 18, slice_ubd 1, rep_ubd 1
file [7/10] rep [1/1] slice [1/1]
current file: fs_0025_3T.pt
Dimension of the current data file: t_ubd 27, slice_ubd 1, rep_ubd 1
file [8/10] rep [1/1] slice [1/1]
current file: fs_0023_3T.pt
Dimension of t

In [36]:
eval_hist_path = '/home/ec2-user/SageMaker/RLsamp/output/Test_REINFORCE_hist_2023-05-29_base8_budget16_REINFORCE_test_magweg1.0_rwd1.pt'
eval_hist = torch.load(eval_hist_path)

In [37]:
print('REINFORCE rmse test : ', np.mean(eval_hist['testing_record']['rmse']))
print('Rand. rmse test     : ', np.mean(eval_hist['testing_record']['rmse_rand']))
print('LowFreq. rmse test  : ', np.mean(eval_hist['testing_record']['rmse_lowfreq']))
print('prob. rmse test     : ', np.mean(eval_hist['testing_record']['rmse_prob']))
print('\n')
print('REINFORCE ssim test : ', np.mean(eval_hist['testing_record']['ssim']))
print('Rand. ssim test     : ', np.mean(eval_hist['testing_record']['ssim_rand']))
print('LowFreq. ssim test  : ', np.mean(eval_hist['testing_record']['ssim_lowfreq']))
print('Prob. ssim test     : ', np.mean(eval_hist['testing_record']['ssim_prob']))

REINFORCE rmse test :  0.39431259496780435
Rand. rmse test     :  0.5638175830914441
LowFreq. rmse test  :  0.3948401342732642
prob. rmse test     :  0.6159621770170558


REINFORCE ssim test :  0.7520666751388545
Rand. ssim test     :  0.5793199391448487
LowFreq. ssim test  :  0.7522064061281856
Prob. ssim test     :  0.4802507269676366
