In [1]:
import RL_samp
from RL_samp.header import *
from RL_samp.utils import *
from RL_samp.replay_buffer import *
from RL_samp.models import poly_net, val_net
from RL_samp.reconstructors import sigpy_solver
from RL_samp.policies import DQN
from RL_samp.trainers import DeepQL_trainer

from importlib import reload
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


In [2]:
import os

# Directory holding the fully-sampled OCMR images stored as .pt files.
# Can be overridden with the OCMR_DATAPATH environment variable; the default
# is the original location.
datapath = os.environ.get('OCMR_DATAPATH', '/mnt/shared_a/OCMR/OCMR_fully_sampled_images/')
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# file order handed to the loader (and therefore training) is reproducible.
ncfiles = sorted(fname for fname in os.listdir(datapath) if fname.endswith(".pt"))

In [None]:
# Smoke-test the OCMR data loader by calling loader.test() iterMax times.
# The original cell used `loader` before any cell had defined it (it is first
# created in the training cells below), so it failed on a fresh kernel. Build
# the loader here with the same arguments those cells use (t_backtrack = 3 in
# both parameter cells).
loader = ocmrLoader(ncfiles, batch_size=1, t_backtrack=3)
loader.reset()
iterMax = 2000
for _ in range(iterMax):
    loader.test()

In [None]:
### DQN Parameter settings

## image parameters: observations are shaped (t_backtrack, heg, wid) below
heg, wid = 192, 144

## reconstructor parameters (forwarded to the DQN policy)
max_iter, L, solver = 50, 5e-3, 'ADMM'

## trainer parameters
discount, memory_len, t_backtrack = 0.5, 20, 3
base, budget = 5, 13
# NOTE(review): save_freq is not passed to DeepQL_trainer in the training cell
episodes, save_freq = 1, 10
batch_size, ngpu = 2, 1
lr, eps = 1e-3, 1e-3
double_q = False

In [None]:
# Pick up edits to RL_samp.trainers without restarting the kernel.
# importlib.reload is not recursive: reloading the package re-runs only its
# __init__, while the cached RL_samp.trainers module (and the DeepQL_trainer
# name bound from it) would stay stale. Reload the submodule explicitly, then
# re-bind the class from the fresh module. Other RL_samp submodules
# (e.g. RL_samp.policies) are not reloaded by this cell.
from importlib import reload

import RL_samp
import RL_samp.trainers

reload(RL_samp)
reload(RL_samp.trainers)
from RL_samp.trainers import DeepQL_trainer


In [None]:
# Build the data loader, replay memory, network and DQN policy, then train.
loader  = ocmrLoader(ncfiles, batch_size=1, t_backtrack=t_backtrack)
memory  = ReplayMemory(capacity=memory_len,
                       curr_obs_shape=(t_backtrack, heg, wid),
                       # NOTE(review): `(wid)` is the int `wid`, not the 1-tuple `(wid,)`,
                       # unlike the other shape arguments -- confirm what ReplayMemory expects.
                       mask_shape=(wid),
                       next_obs_shape=(1, heg, wid),
                       batch_size=batch_size,
                       burn_in=batch_size)
model   = poly_net(samp_dim=wid)
# `max_iter` was previously passed twice, which Python rejects at compile time
# ("keyword argument repeated"); it is passed exactly once here.
policy  = DQN(model, memory, max_iter=max_iter, ngpu=ngpu, gamma=discount, lr=lr,
              double_q_mode=double_q, solver=solver, L=L)
trainer = DeepQL_trainer(loader, policy, episodes=episodes,
                         eps=eps,
                         base=base, budget=budget,
                         ngpu=ngpu)
trainer.train()

In [3]:
### AC1 Parameter settings

## image parameters
heg, wid = 192, 144

## reconstructor parameters (forwarded to AC1_trainer)
max_iter, L, solver = 50, 5e-3, 'ADMM'

## trainer parameters
discount, t_backtrack = 0.9, 3
base, budget = 5, 13
# NOTE(review): episodes, save_freq, batch_size, eps and double_q are not
# passed to AC1_trainer in the training cell
episodes, save_freq = 1, 10
batch_size, ngpu = 2, 0
lr, eps = 1e-3, 1e-3
double_q = False

In [4]:
# Pick up edits to RL_samp.trainers without restarting the kernel.
# importlib.reload is not recursive: reloading the package re-runs only its
# __init__, so the cached RL_samp.trainers module must be reloaded explicitly
# before AC1_trainer is re-bound from it. Other RL_samp submodules are not
# reloaded by this cell.
from importlib import reload

import RL_samp
import RL_samp.trainers

reload(RL_samp)
reload(RL_samp.trainers)
from RL_samp.trainers import AC1_trainer

In [5]:
# Loader over the .pt files collected in `ncfiles`.
loader = ocmrLoader(
    ncfiles,
    batch_size=1,
    t_backtrack=t_backtrack,
)

# Policy network (built with softmax=True) and value network for AC1_trainer.
p_net = poly_net(samp_dim=wid, softmax=True)
v_net = val_net()

trainer = AC1_trainer(
    loader,
    polynet=p_net,
    valnet=v_net,
    base=base,
    budget=budget,
    gamma=discount,
    lr=lr,
    solver=solver,
    max_iter=max_iter,
    L=L,
    ngpu=ngpu,
)
trainer.run()

current file: fs_0009_1_5T.pt
Dimension of the current data file: t_ubd 19, slice_ubd 12, rep_ubd 1
epoch [1/100] file [1/43] rep [1/1] slice [1/12]
> [0;32m/home/huangz78/rl_samp/RL_samp/trainers.py[0m(282)[0;36mrun[0;34m()[0m
[0;32m    280 [0;31m                [0mself[0m[0;34m.[0m[0moptimizer_val[0m[0;34m.[0m[0mzero_grad[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    281 [0;31m                [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 282 [0;31m                [0mval_loss[0m [0;34m=[0m [0;34m-[0m [0mdelta[0m [0;34m*[0m [0mv[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    283 [0;31m                [0mval_loss[0m[0;34m.[0m[0mbackward[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    284 [0;31m                [0mself[0m[0;34m.[0m[0moptimizer_val[0m[0;34m.[0m[0mstep[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> unt 313
step: 1, poly_loss: 0.0091,

AttributeError: 'AC1_trainer' object has no attribute 'freq_dqn_checkpoint_save'

## View training history

In [None]:
import torch

# Path to a saved DQN training-history file (a single .pt file, despite the name).
hist_dir = '/home/huangz78/rl_samp/DQN_hist_8fold_base5_budget13_lr1e-3_epoch1.pt'
# map_location='cpu' lets a history saved with GPU tensors load on a CPU-only
# kernel (the AC1 settings use ngpu = 0); the plotting cell only needs CPU values.
# torch.load unpickles the file -- only load files from a trusted source.
data = torch.load(hist_dir, map_location='cpu')

In [None]:
# Training metrics stored under 'training_record' in the loaded file; the
# plotting cell treats this as dict-like (it calls .keys() and indexes by name).
hist = data['training_record']

In [None]:
# Plot the recorded training curves on a 3x2 grid.
# Panel titles are the metric names themselves. The previous version took
# titles from positions in list(hist.keys()) (keys[-1], keys[0], ...), which
# mislabels panels whenever the record has a different key order or extra keys.
figsize = (12, 13)
keys = list(hist.keys())


def _as_floats(values):
    """Convert a sequence of scalars / 0-d tensors into a list of Python floats."""
    return [v.item() if hasattr(v, 'item') else float(v) for v in values]


loss_hist = _as_floats(hist['loss'])
q_values_mean_hist = _as_floats(hist['q_values_mean'])
q_values_std_hist = _as_floats(hist['q_values_std'])

# (grid position, metric name used as title, series, log-scale y-axis?)
panels = [
    ((0, 0), 'horizon_rewards', hist['horizon_rewards'], False),
    ((0, 1), 'loss', loss_hist, True),
    ((1, 0), 'grad_norm', hist['grad_norm'], True),
    ((1, 1), 'q_values_mean', q_values_mean_hist, False),
    ((2, 0), 'q_values_std', q_values_std_hist, True),
]

fig, axs = plt.subplots(nrows=3, ncols=2, figsize=figsize)
for (row, col), name, series, log_y in panels:
    ax = axs[row, col]
    ax.plot(series)
    ax.set_title(name)
    ax.set_xlabel('record index')
    ax.set_ylabel(name)
    if log_y:
        ax.set_yscale('log')

# Five metrics on a 3x2 grid: hide the unused last panel.
axs[2, 1].axis('off')

plt.show()