In [1]:
import RL_samp
from RL_samp.header import *
from RL_samp.utils import *
from RL_samp.replay_buffer import *
from RL_samp.models import poly_net, val_net
from RL_samp.reconstructors import sigpy_solver
from RL_samp.policies import DQN
from RL_samp.trainers import DeepQL_trainer, AC1_ET_trainer

from importlib import reload
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


In [None]:
import torch.nn.functional as Func

In [2]:
# Collect the fully-sampled OCMR image files (*.pt) that the loaders iterate over.
# The directory can be overridden with the OCMR_DATAPATH environment variable; the
# default is the original shared-drive location.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so sort the list.
# That way the loaders see the files in the same order on every machine and run.
import os

datapath = os.environ.get('OCMR_DATAPATH', '/mnt/shared_a/OCMR/OCMR_fully_sampled_images/')
ncfiles = sorted(fname for fname in os.listdir(datapath) if fname.endswith(".pt"))

In [None]:
# Smoke-test the OCMR loader by calling its `test()` method many times in a row.
# Build a dedicated loader here rather than reusing `loader`. That name is only
# created in a later cell, so the original cell failed on a fresh kernel.
LOADER_SMOKE_T_BACKTRACK = 3  # matches `t_backtrack` in the parameter-settings cells
smoke_loader = ocmrLoader(ncfiles, batch_size=1, t_backtrack=LOADER_SMOKE_T_BACKTRACK)
smoke_loader.reset()
iterMax = 2000
for _ in range(iterMax):
    smoke_loader.test()

In [3]:
### DQN Parameter settings
# These are module-level globals. The AC1 parameter cell further down assigns the
# same names (with some different values), so re-run this cell before the DQN cells.

## image parameters: observations are (t_backtrack, heg, wid); wid is the sampled axis
heg, wid = 192, 144

## reconstructor parameters
max_iter, L, solver = 50, 5e-3, 'ADMM'

## trainer parameters
discount, memory_len, t_backtrack = .5, 20, 3
base, budget = 5, 13
episodes, save_freq = 1, 10
batch_size, ngpu = 2, 1
lr, eps = 1e-3, 1e-3
double_q = False

In [None]:
# Pick up edits to RL_samp/trainers.py without restarting the kernel.
# importlib.reload(RL_samp) only re-executes RL_samp/__init__.py. Submodules already
# in sys.modules (such as RL_samp.trainers) keep their old code. So reload the submodule
# itself, then re-bind the class name from the freshly loaded module.
# NOTE(review): other edited submodules (e.g. RL_samp.policies for DQN) are not
# reloaded here. Reload them too, or use `%autoreload 2`.
from importlib import reload

import RL_samp
import RL_samp.trainers

reload(RL_samp)
reload(RL_samp.trainers)
from RL_samp.trainers import DeepQL_trainer


In [None]:
# Build the data loader, replay memory, network and DQN policy, then train with DeepQL_trainer.
loader  = ocmrLoader(ncfiles,batch_size=1,t_backtrack=t_backtrack)
memory  = ReplayMemory(capacity=memory_len,
                       curr_obs_shape=(t_backtrack,heg,wid),
                       # NOTE(review): `(wid)` is just the int `wid`, not a 1-tuple like the
                       # other shapes. Confirm ReplayMemory expects an int here; otherwise use (wid,).
                       mask_shape=(wid),
                       next_obs_shape=(1,heg,wid),
                       batch_size=batch_size,
                       burn_in=batch_size)
model   = poly_net(samp_dim=wid)
# Pass `max_iter` only once. Repeating a keyword argument in a call is a SyntaxError
# ("keyword argument repeated"), which stopped this whole cell from compiling.
policy  = DQN(model,memory,ngpu=ngpu,gamma=discount,lr=lr,double_q_mode=double_q,
              solver=solver,max_iter=max_iter,L=L)
trainer = DeepQL_trainer(loader,policy,episodes=episodes,
                         eps=eps,
                         base=base,budget=budget,
                         ngpu=ngpu)
trainer.train()

In [None]:
### AC1 Parameter settings
# NOTE: this cell overwrites the same global names set by the DQN parameter cell
# (e.g. discount and ngpu change), so re-run the cell for the agent you are about to train.

## image parameters
heg, wid = 192, 144

## reconstructor parameters
max_iter = 50
L        = 5e-3
solver   = 'ADMM'

## trainer parameters
discount, t_backtrack = .9, 3
base, budget          = 5, 13
episodes, save_freq   = 1, 10
batch_size            = 2
ngpu                  = 0
lr, eps               = 1e-3, 1e-3
double_q              = False   # eps/double_q are not passed to the AC1 trainer calls below

In [None]:
# Pick up edits to RL_samp/trainers.py without restarting the kernel.
# reload(RL_samp) alone re-executes only RL_samp/__init__.py and leaves the already
# imported RL_samp.trainers submodule stale. Reload the submodule explicitly, then
# re-bind AC1_trainer from it.
from importlib import reload

import RL_samp
import RL_samp.trainers

reload(RL_samp)
reload(RL_samp.trainers)
from RL_samp.trainers import AC1_trainer

In [None]:
# Train the AC1 actor-critic agent. The policy net has a softmax head over `wid`
# sampling positions (samp_dim=wid); the value net uses its default settings.
loader = ocmrLoader(ncfiles, batch_size=1, t_backtrack=t_backtrack)
p_net = poly_net(samp_dim=wid, softmax=True)
v_net = val_net()
trainer = AC1_trainer(
    loader,
    polynet=p_net,
    valnet=v_net,
    base=base,
    budget=budget,
    gamma=discount,
    lr=lr,
    solver=solver,
    max_iter=max_iter,
    L=L,
    ngpu=ngpu,
)
trainer.run()

In [16]:
# Build the AC1_ET_trainer (training itself is started in the next cell).
# Named constants for the value-net shaping and reward scaling used by this run:
VALNET_SLOPE = .5
VALNET_SCALE = 10
REWARD_SCALE = 9e2

loader = ocmrLoader(ncfiles, batch_size=1, t_backtrack=t_backtrack)
p_net = poly_net(samp_dim=wid, softmax=True)
v_net = val_net(slope=VALNET_SLOPE, scale=VALNET_SCALE)
# NOTE(review): unlike the AC1_trainer cell, the configured `lr` is not passed here.
# Confirm whether that is intended.
trainer = AC1_ET_trainer(
    loader,
    polynet=p_net,
    valnet=v_net,
    base=base,
    budget=budget,
    gamma=discount,
    solver=solver,
    max_iter=max_iter,
    L=L,
    reward_scale=REWARD_SCALE,
    ngpu=ngpu,
)

current file: fs_0074_1_5T.pt
Dimension of the current data file: t_ubd 19, slice_ubd 12, rep_ubd 1


In [17]:
# Run AC1-ET training.
# NOTE(review): the saved output of this cell stopped in the debugger at a `breakpoint()`
# left in RL_samp/trainers.py `run()` (line 566) and was quit with BdbQuit. Remove that
# breakpoint before running training unattended.
trainer.run()

epoch [1/100] file [1/43] rep [1/1] slice [1/12]
> /home/huangz78/rl_samp/RL_samp/trainers.py(567)run()
    565                     vnew  = self.valnet(next_obs) if t<self.horizon-1 else 0
    566                     breakpoint()
--> 567                     delta = reward + self.gamma * vnew  - v # should check if delta == 0
    568                     print(f'step {self.steps}, delta {delta.item()}')
    569

BdbQuit: 

## view training history

In [None]:
# Load the saved AC1-ET training history checkpoint.
# map_location='cpu' remaps any stored tensors onto the CPU. A history written during a
# GPU run then opens on a CPU-only machine and converts cleanly with numpy below.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
hist_dir = '/home/huangz78/rl_samp/AC1_ET_hist_base5_budget13.pt'
data = torch.load(hist_dir, map_location='cpu')

In [None]:
# The checkpoint keeps the logged metrics under 'training_record'; list which ones exist.
hist = data['training_record']
print(hist.keys())

In [None]:
# Stack the RMSE comparison records into an array.
# Assumes each entry of hist['rmse_cmp'] holds 3 values: AC1-ET, random and low-frequency
# RMSE, matching the column indexing in the plot cell. TODO confirm against the trainer.
rmse = np.array(hist['rmse_cmp'])

In [None]:
# Compare the reconstruction RMSE of the learned AC1-ET sampling policy with the random
# and low-frequency baselines (columns 0, 1, 2 of `rmse`).
fig, ax = plt.subplots()
for col, label in enumerate(['AC1-ET', 'rand', 'low. freq.']):
    ax.plot(rmse[:, col], label=label)
ax.set_xlabel("logged entry (index into hist['rmse_cmp'])")
ax.set_ylabel('RMSE')
ax.set_title('Reconstruction RMSE: AC1-ET vs. baselines')
ax.legend()
plt.show()

In [None]:
# AC1-ET training diagnostics: a 2x3 grid with one panel per logged metric.
figsize = (12,13)
keys = list(hist.keys())
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=figsize)

# (row, col, metric key, panel title, log-scale y axis?)
# The two loss panels use a linear scale (log scale was tried and left disabled).
panel_specs = [
    (0, 0, 'horizon_rewards', 'horizon_rewards', False),
    (0, 1, 'poly_loss',       'loss - polynet',  False),
    (0, 2, 'val_loss',        'loss - valnet',   False),
    (1, 0, 'poly_grad_norm',  'poly_grad_norm',  True),
    (1, 1, 'val_grad_norm',   'val_grad_norm',   True),
    (1, 2, 'action_prob',     'action prob',     False),
]
for row, col, metric, title, log_y in panel_specs:
    ax = axs[row, col]
    ax.plot(hist[metric])
    ax.set_title(title)
    if log_y:
        ax.set_yscale('log')

plt.show()

In [None]:
# Load the saved DQN (double-Q) training history and list the metrics it recorded.
# map_location='cpu' lets a checkpoint written during a GPU run be opened on a CPU-only
# machine and keeps any stored tensors plottable without .cpu() calls.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
hist_dir = '/home/huangz78/rl_samp/DQN_doubleQ_True_hist.pt'
data = torch.load(hist_dir, map_location='cpu')
hist = data['training_record']
print(hist.keys())

In [None]:
# DQN training diagnostics: a 3x2 grid with one panel per recorded metric.
# Panel titles use the name of the metric actually plotted. The previous positional
# lookups (keys[-1], keys[0], ...) only matched if the saved record happened to store
# its keys in one particular order, and could otherwise mislabel panels.
figsize = (12,13)
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=figsize)

# (row, col, metric key, log-scale y axis?)
panel_specs = [
    (0, 0, 'horizon_rewards', False),
    (0, 1, 'loss',            True),
    (1, 0, 'grad_norm',       True),
    (1, 1, 'q_values_mean',   False),
    (2, 0, 'q_values_std',    True),
]
for row, col, metric, log_y in panel_specs:
    ax = axs[row, col]
    ax.plot(hist[metric])
    ax.set_title(metric)
    if log_y:
        ax.set_yscale('log')

# RMSE of the DQN policy compared with the low-frequency and random baselines.
ax_rmse = axs[2, 1]
ax_rmse.plot(hist['rmse'], label='DQN')
ax_rmse.plot(hist['rmse_lowfreq'], label='low freq.')
ax_rmse.plot(hist['rmse_rand'], '.', linewidth=.5, label='rand')
ax_rmse.set_title('rmse')
ax_rmse.legend(loc='best')

plt.show()