In [None]:
from reinforce import *
from header import *
from utils import *
from replay_buffer import *
from models import poly_net
from reconstructors import sigpy_solver
from importlib import reload
# Replay-memory capacity and observation geometry shared by the cells below.
memory_len = 10        # capacity of the ReplayMemory
t_backtrack = 3        # leading dim of the current-observation shape (curr_obs_shape)
heg, wid = 192, 144    # image height / width used for observation and mask shapes

In [None]:
# Directory of fully-sampled OCMR images (`.pt` files, presumably torch-saved tensors).
datapath = '/mnt/shared_a/OCMR/OCMR_fully_sampled_images/'
# Sorted because os.listdir returns entries in arbitrary order; a fixed order
# keeps the loader's sample sequence reproducible across kernel restarts.
ncfiles = sorted(file for file in os.listdir(datapath) if file.endswith(".pt"))
loader = ocmrLoader(ncfiles)

In [None]:
class RL_trainer_REINFORCE():
    """REINFORCE (Monte-Carlo policy-gradient) trainer for a sampling-mask policy.

    Each trajectory loads one ``(data_source, data_target)`` pair, starts from a
    mask built by ``mask_naiveRand(fulldim, fix=base, other=0, roll=False)``,
    lets ``policy`` choose one action per step for ``horizon`` steps, and then
    takes one optimizer step on ``-sum_t log(sigmoid(model(o_t))[a_t]) * G_t``.

    Args:
        dataloader: object with ``reset()`` and ``load() -> (data_source, data_target)``.
        policy: object with ``get_action(data_source, mask=..., eps_threshold=...)``
            and ``step(action, data_target, mask) -> (next_obs, reward)``.
        memory: replay memory; kept on the instance, not used by ``REINFORCE()``.
        trajectories: number of trajectories (= gradient steps) per ``REINFORCE()`` call.
        eps: exploration threshold forwarded to ``policy.get_action``.
        fulldim: first argument of ``mask_naiveRand`` (presumably the mask length).
        base: ``fix`` argument of the initial ``mask_naiveRand`` mask.
        budget: kept on the instance, not used by ``REINFORCE()``.
        gamma: discount factor of the returns ``G_t``.
        horizon: number of steps per trajectory.
        model: network scoring observations; defaults to ``policy.model`` if present.
        optimizer: optimizer over ``model``'s parameters; defaults to
            ``policy.optimizer`` if present, else Adam(``lr``) created on first use.
        lr: learning rate of the default Adam optimizer.
    """

    def __init__(self,dataloader,policy,memory,trajectories:int=100,eps:float=1e-3,
                 fulldim:int=144,base:int=10,budget:int=50,
                 gamma:float=0.99,horizon:int=500,model=None,optimizer=None,lr:float=1e-4):
        self.dataloader = dataloader
        self.dataloader.reset()

        self.policy   = policy
        self.memory   = memory
        self.trajectories = trajectories
        self.epi = 0
        self.fulldim = fulldim
        self.base = base
        self.budget = budget
        self.eps = eps
        self.gamma = gamma
        self.training_record = {'loss':[],'grad_norm':[],'q_values_mean':[],'q_values_std':[]}
        self.steps = 0
        # NOTE(review): the default 500 steps exceeds `fulldim` (144) candidate lines;
        # `budget` may be the intended horizon — TODO confirm.
        self.horizon = horizon
        # Fall back to the policy's own network / optimizer when none is passed
        # (assumes the policy may expose `.model` / `.optimizer` — TODO confirm).
        self.model = model if model is not None else getattr(policy, 'model', None)
        self.optimizer = optimizer if optimizer is not None else getattr(policy, 'optimizer', None)
        self.lr = lr

    def REINFORCE(self):
        """Run ``self.trajectories`` rollouts, each followed by one policy-gradient step.

        Returns:
            The mask reached at the end of the last trajectory (the initial base
            mask when ``self.trajectories`` is 0).

        Raises:
            AttributeError: if no model was given and the policy has no ``.model``.
            ValueError: if ``self.horizon`` < 1 (empty trajectory).
        """
        optimizer = self._get_optimizer()
        mask = self._initial_mask()
        for _ in range(self.trajectories):
            transitions, mask = self._rollout()
            loss = self._policy_loss(transitions)
            # gradient ascent on expected return == descent on the negated objective
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            self.training_record['loss'].append(loss.item())
            self.epi += 1
        return mask

    def _get_optimizer(self):
        """Return the optimizer, creating Adam over ``self.model`` on first use."""
        if self.model is None:
            raise AttributeError(
                "RL_trainer_REINFORCE needs a network: pass `model=` or use a policy with a `.model` attribute")
        if self.optimizer is None:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return self.optimizer

    def _initial_mask(self):
        """Starting mask of every trajectory (same call shape as the random baseline)."""
        return mask_naiveRand(self.fulldim, fix=self.base, other=0, roll=False)

    def _rollout(self):
        """Roll the current policy out for ``self.horizon`` steps on one loaded sample.

        Returns:
            ``(transitions, mask)``: a time-ordered list of ``(obs, action, reward)``
            tuples and the mask after the last step.
        """
        # Load once per trajectory so that every step acts on the same scan.
        data_source, data_target = self.dataloader.load()
        mask = self._initial_mask()
        transitions = []
        for _ in range(self.horizon):
            mask_RL  = copy.deepcopy(mask)
            curr_obs = fft_observe(data_source, mask_RL)
            action   = self.policy.get_action(data_source, mask=mask_RL, eps_threshold=self.eps)
            # `step` presumably updates mask_RL in place with the chosen action,
            # since mask_RL is carried forward as the next mask — TODO confirm.
            _, reward = self.policy.step(action, data_target, mask_RL)
            transitions.append((curr_obs, action, reward))
            mask = copy.deepcopy(mask_RL)
            self.steps += 1
        return transitions, mask

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """``G_t = sum_{k>=t} gamma**(k-t) * r_k``, aligned with the input order, in O(T)."""
        returns = []
        running = 0.0
        for r in reversed(list(rewards)):
            running = float(r) + gamma * running
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32)

    @staticmethod
    def _scale_returns(returns):
        """Scale returns into [-1, 1] by their largest magnitude.

        Dividing by the magnitude (not the signed max) keeps the sign of every
        return; all-zero or empty returns are passed through unchanged.
        """
        if returns.numel() == 0:
            return returns
        scale = returns.abs().max()
        return returns / scale if scale > 0 else returns

    def _policy_loss(self, transitions):
        """REINFORCE loss ``-sum_t log pi(a_t | o_t) * G_t`` for one trajectory.

        Raises:
            ValueError: if ``transitions`` is empty.
        """
        if not transitions:
            raise ValueError("cannot compute a policy loss from an empty trajectory (horizon must be >= 1)")
        observations, actions, rewards = zip(*transitions)
        returns = self._scale_returns(self._discounted_returns(rewards, self.gamma))

        obs_batch = torch.stack([torch.as_tensor(o) for o in observations])
        # TODO (from original): need to pass in a mask other than None?
        logits = self.model(obs_batch).reshape(len(transitions), -1)
        # Assumes each action is a scalar index into the model output — TODO confirm.
        action_batch = torch.as_tensor([int(a) for a in actions], dtype=torch.long)
        # logsigmoid == log(sigmoid(x)) but does not underflow to log(0) = -inf.
        log_probs = torch.nn.functional.logsigmoid(logits).gather(1, action_batch.view(-1, 1)).squeeze(1)
        return -torch.sum(log_probs * returns.to(log_probs.device))

In [None]:
# Build the replay memory, the policy network and the REINFORCE trainer.
memory  = ReplayMemory(capacity=memory_len,
                       curr_obs_shape=(t_backtrack,heg,wid),
                       # `(wid)` is just the int `wid`; a 1-tuple matches the other shape arguments.
                       mask_shape=(wid,),
                       next_obs_shape=(1,heg,wid),
                       batch_size=2,
                       burn_in=2)
model   = poly_net(samp_dim=wid)
policy  = REINFORCE(model,memory)
trainer = RL_trainer_REINFORCE(loader,policy,memory)

In [None]:
# Hot-reload project modules after editing their source (development aid).
# NOTE(review): objects built in earlier cells (e.g. `memory`) keep their
# pre-reload classes; re-run their construction cell to pick up edits.
# NOTE(review): `DQN` is not referenced by the visible cells.
import dqn
reload(dqn)
from dqn import DQN

import replay_buffer
reload(replay_buffer)
from replay_buffer import *

import utils
reload(utils)
from utils import *

In [None]:
# Train with REINFORCE; the returned sampling mask is compared with a random baseline below.
rein_mask = trainer.REINFORCE()

In [None]:
def observe_and_reconstruct(target, mask):
    """Observe `target` through `mask` (frequency-domain output) and reconstruct with sigpy.

    Returns:
        (obs_freq, recon): the masked frequency observation and its reconstruction.
    """
    obs_freq = fft_observe(target, mask, return_opt='freq')
    recon = sigpy_solver(obs_freq, heg=target.shape[2], wid=target.shape[3])
    return obs_freq, recon

# Baseline mask built with the trainer's own geometry, so both masks have the same length.
fulldim = trainer.fulldim
base = trainer.base
# NOTE(review): `other=0` presumably adds no extra random lines, so this baseline may
# sample fewer lines than `rein_mask`; a budget-matched baseline may be intended — TODO confirm.
rand_mask = mask_naiveRand(fulldim,fix=base,other=0,roll=False)

# Reconstruct the same target image from each mask.
data_source, data_target = loader.load()
rand_obs_freq, img_recon_rand = observe_and_reconstruct(data_target, rand_mask)
rein_obs_freq, img_recon_rein = observe_and_reconstruct(data_target, rein_mask)

# NOTE(review): this NRMSE is between the two reconstructions, i.e. how much they
# disagree; it does not say which one is closer to `data_target`.
nrmse = NRMSE(img_recon_rand, img_recon_rein)
nrmse