In [None]:
from header import *
from utils import *
from replay_buffer import *
from models import poly_net
from reconstructors import sigpy_solver
from dqn import DQN
from importlib import reload
# --- experiment configuration ---
memory_len = 10              # capacity handed to ReplayMemory
t_backtrack = 3              # leading axis of the current-observation shape
heg, wid = 192, 144          # observation height / width; wid is also the mask length used below

In [None]:
# Directory holding the fully sampled OCMR images (one .pt file per sample).
datapath = '/mnt/shared_a/OCMR/OCMR_fully_sampled_images/'
# os.listdir returns entries in arbitrary order; sorting makes the file list,
# and hence the order in which the loader sees files, reproducible across runs.
# NOTE(review): ocmrLoader receives bare file names, so it presumably resolves
# them against its own data directory -- keep that in sync with `datapath`.
ncfiles = sorted(file for file in os.listdir(datapath) if file.endswith(".pt"))
loader = ocmrLoader(ncfiles)

In [None]:
class RL_trainer_PG():
    """REINFORCE-style policy-gradient trainer for mask acquisition.

    Every episode does two things:

    1. Policy-gradient step. Starting from a fresh ``mask_naiveRand`` mask,
       roll the policy out for at most ``horizon`` steps (stopping early once
       ``mask.sum()`` reaches ``budget + base``), then take one REINFORCE step
       that minimises ``-sum_t G_t * log p(a_t | s_t)``, where ``G_t`` are the
       discounted returns of the rollout.
    2. Acquisition / comparison loop. Starting again from a fresh mask, act one
       step at a time until ``mask.sum()`` reaches ``budget + base``, pushing
       each transition into ``memory``, calling ``policy.update_parameters()``
       and printing the RL reward next to the reward obtained by
       ``policy.get_rand_action`` on the same data.

    Args:
        dataloader: object with ``reset()`` and ``load()``; ``load()`` returns a
            ``(data_source, data_target)`` pair.
        policy: object providing ``get_action``, ``get_rand_action``, ``step``
            and ``update_parameters``.
        memory: replay memory with ``push(curr_obs, mask, action, next_obs, reward)``.
        episodes: number of episodes ``train`` runs.
        eps: exploration threshold forwarded to ``policy.get_action``.
        fulldim: mask length passed to ``mask_naiveRand``.
        base: passed as ``fix`` to ``mask_naiveRand`` for each starting mask.
        budget: both loops stop once ``mask.sum()`` reaches ``budget + base``.
        gamma: discount factor for the REINFORCE returns.
        horizon: maximum rollout length for the REINFORCE step.
        model: network scored by the REINFORCE loss. When None, ``policy.model``
            is used (resolved on the first update).
        optimizer: optimizer for ``model``. When None, ``policy.optimizer`` is
            used if present, otherwise a new ``torch.optim.Adam`` with ``lr``.
        lr: learning rate of the fallback Adam optimizer.
    """

    def __init__(self,dataloader,policy,memory,episodes:int=10,eps:float=1e-3,
                 fulldim:int=144,base:int=10,budget:int=50,
                 gamma:float=0.99,horizon:int=500,model=None,optimizer=None,
                 lr:float=1e-4):
        self.dataloader = dataloader
        self.dataloader.reset()

        self.policy   = policy
        self.memory   = memory
        self.episodes = episodes
        self.epi = 0
        self.fulldim = fulldim
        self.base = base
        self.budget = budget
        self.eps = eps
        self.gamma = gamma
        self.horizon = horizon
        # Model/optimizer are resolved lazily (see _resolve_model_and_optimizer)
        # so constructing the trainer never depends on policy internals.
        self.model = model
        self.optimizer = optimizer
        self.lr = lr
        self.training_record = {'loss':[],'grad_norm':[],'q_values_mean':[],'q_values_std':[]}
        # Kept outside training_record: that dict's keys must all be present in
        # the results of policy.update_parameters().
        self.pg_losses = []
        self.steps = 0

    @staticmethod
    def compute_returns(rewards, gamma):
        """Discounted returns ``G_t = r_t + gamma * G_{t+1}`` along dim 0.

        Args:
            rewards: sequence of floats or a tensor whose first axis is time;
                any trailing axes are treated independently.
            gamma: discount factor.

        Returns:
            float32 tensor with the same shape as ``rewards`` (empty in, empty out).
        """
        rewards = torch.as_tensor(rewards, dtype=torch.float32).detach()
        returns = torch.zeros_like(rewards)
        running = 0.0
        # Single reverse sweep: O(T) instead of re-summing every suffix.
        for t in range(rewards.shape[0] - 1, -1, -1):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns

    def train(self):
        """Run the remaining episodes (from ``self.epi`` up to ``self.episodes``)."""
        while self.epi < self.episodes:
            print(f'episode [{self.epi+1}/{self.episodes}]')

            trajectory = self._collect_trajectory(self._initial_mask())
            pg_loss = self._reinforce_update(trajectory)
            if pg_loss is not None:
                self.pg_losses.append(pg_loss)
                print(f'episode [{self.epi+1}/{self.episodes}], REINFORCE loss: {pg_loss:.4f}, trajectory length: {len(trajectory)}')

            # one mask at a time, start again from a low frequency mask
            # (the rollout above may already have used up the budget)
            self._acquire(self._initial_mask())

            self.dataloader.reset()
            self.epi += 1

    def _initial_mask(self):
        """Fresh starting mask for an episode phase."""
        return mask_naiveRand(self.fulldim,fix=self.base,other=0,roll=False)

    def _has_budget(self, mask):
        """True while ``mask`` holds fewer than ``budget + base`` sampled entries."""
        return mask.sum() < self.budget + self.base

    def _collect_trajectory(self, mask):
        """Roll the policy out from ``mask``; return a list of (obs, action, reward)."""
        trajectory = []
        for _ in range(self.horizon):
            if not self._has_budget(mask):
                break
            data_source, data_target = self.dataloader.load()
            mask_RL  = copy.deepcopy(mask)
            curr_obs = fft_observe(data_source,mask_RL)
            action   = self.policy.get_action(data_source, mask=mask_RL, eps_threshold=self.eps)
            # policy.step is relied on to update mask_RL in place with the action
            _, reward = self.policy.step(action, data_target, mask_RL)
            trajectory.append((curr_obs, action, reward))
            mask = copy.deepcopy(mask_RL)
        return trajectory

    def _resolve_model_and_optimizer(self):
        """Return ``(model, optimizer)``, falling back to the policy's and caching them."""
        if self.model is None:
            # NOTE(review): assumes the policy keeps its network as `.model` -- confirm.
            self.model = getattr(self.policy, 'model', None)
        if self.model is None:
            raise AttributeError("RL_trainer_PG needs a model: pass model=... or use a policy with a .model attribute")
        if self.optimizer is None:
            self.optimizer = getattr(self.policy, 'optimizer', None)
        if self.optimizer is None:
            # Cached on self so optimizer state persists across episodes.
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return self.model, self.optimizer

    def _reinforce_update(self, trajectory):
        """Take one REINFORCE gradient step on ``trajectory``.

        Returns:
            The scalar loss as a float, or None when ``trajectory`` is empty.

        Raises:
            ValueError: if observations, actions and rewards do not line up.
        """
        if not trajectory:
            return None
        model, optimizer = self._resolve_model_and_optimizer()
        obs_list, action_list, reward_list = zip(*trajectory)

        # (T, B): one row per time step, one column per reward entry of that step.
        rewards = torch.stack([torch.as_tensor(r, dtype=torch.float32).detach().reshape(-1)
                               for r in reward_list])
        returns = self.compute_returns(rewards, self.gamma).reshape(-1)
        # Scale by the largest magnitude: dividing by max() could flip the
        # gradient's sign (negative max) or divide by zero.
        scale = returns.abs().max()
        if scale > 0:
            returns = returns / scale

        # Assumes each curr_obs carries a leading batch axis whose length equals
        # the number of reward entries of its step -- TODO confirm.
        obs_batch = torch.cat([torch.as_tensor(o).detach() for o in obs_list])
        action_batch = torch.cat([torch.as_tensor(a).detach().reshape(-1)
                                  for a in action_list]).long()
        if obs_batch.shape[0] != returns.numel() or action_batch.numel() != returns.numel():
            raise ValueError(
                f'trajectory mismatch: {obs_batch.shape[0]} observations, '
                f'{action_batch.numel()} actions, {returns.numel()} rewards')

        pred_batch = model(obs_batch)
        device = pred_batch.device
        # Model output is treated as per-action probabilities; squeeze(1) keeps
        # a length-1 batch one-dimensional.
        prob_batch = pred_batch.gather(dim=1,index=action_batch.to(device).view(-1,1)).squeeze(1)

        # perform gradient ascent on sum_t G_t * log p(a_t | s_t); the clamp
        # avoids log(0) = -inf for actions the model gives zero probability.
        loss = - torch.sum(torch.log(prob_batch.clamp_min(1e-8)) * returns.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def _acquire(self, mask):
        """Act one step at a time until the budget is used, comparing RL with random."""
        while self._has_budget(mask):
            self.steps += 1
            data_source, data_target = self.dataloader.load()
            mask_RL   = copy.deepcopy(mask)
            mask_rand = copy.deepcopy(mask)
            curr_obs = fft_observe(data_source,mask_RL)
            action   = self.policy.get_action(data_source, mask=mask_RL, eps_threshold=self.eps)
            next_obs, reward = self.policy.step(action, data_target, mask_RL)

            # `mask` is still the pre-action mask here; mask_RL is the post-action one.
            self.memory.push(curr_obs, mask, action, next_obs, reward)
            mask = copy.deepcopy(mask_RL)

            ### compare with random policy
            with torch.no_grad():
                action_rand = self.policy.get_rand_action(mask=mask_rand)
                _, reward_rand = self.policy.step(action_rand, data_target, mask_rand)
            ###

            update_results = self.policy.update_parameters()
            if update_results is not None:
                for key in self.training_record.keys():
                    self.training_record[key].append(update_results[key])
                curr_loss = update_results['loss']
                print(f'step: {self.steps}, loss: {curr_loss:.4f}, RL reward: {reward.mean().item():.4f}, Rand reward: {reward_rand.mean().item():.4f} \n mask sum: {mask.sum().item()}')
                torch.cuda.empty_cache()
            else:
                print(f'step: {self.steps}, burn in, mask sum: {mask.sum().item()}')

In [None]:
import reinforce
reload(reinforce)
from reinforce import REINFORCE

import replay_buffer
reload(replay_buffer)
from replay_buffer import *

import utils
reload(utils)
from utils import *

In [None]:
# Replay memory, policy network, REINFORCE policy and trainer for this run.
# NOTE(review): burn_in presumably is how many pushes happen before
# policy.update_parameters() starts returning results -- confirm in replay_buffer.
memory = ReplayMemory(
    capacity=memory_len,
    curr_obs_shape=(t_backtrack, heg, wid),
    # mask_shape is the plain int `wid`, not the 1-tuple `(wid,)` -- confirm
    # that is what ReplayMemory expects here.
    mask_shape=wid,
    next_obs_shape=(1, heg, wid),
    batch_size=2,
    burn_in=2,
)
model = poly_net(samp_dim=wid)
policy = REINFORCE(model, memory)
trainer = RL_trainer_PG(dataloader=loader, policy=policy, memory=memory)

In [None]:
# Runs the remaining episodes. trainer.epi starts at 0 in the constructor and is
# only ever incremented, so re-running this cell resumes where it stopped (or
# does nothing once all episodes are done); re-run the construction cell to restart.
trainer.train()