In [None]:
cd drive/MyDrive/JAIST/Research/RL/task-grouping

In [None]:
# Shell out to `make` and run the Makefile's `style` target from the current
# working directory (presumably a formatter/linter pass -- confirm in the Makefile).
!make style

In [None]:
from src.environment.ac_control import ACControl
from src.environment.interaction_buffer import Buffer
import numpy as np

In [None]:
class Agent:
    """Stochastic policy that samples one of seven actions from a fixed,
    state-dependent categorical distribution.

    The state ``tmp`` (presumably an integer temperature reading -- confirm
    against the environment) is bucketed into six bins covering 0..49. Each
    bin has its own probability vector over the action indices 0..6. The
    three "hot" vectors are the element-wise reversal of the matching "cold"
    vectors.
    """

    # (state bin, name of the attribute holding that bin's distribution).
    # Attribute names are looked up at call time so that reassigning e.g.
    # ``agent.cold_dist`` after construction is still honoured.
    _TEMPERATURE_BINS = (
        (range(0, 10), "cold_dist"),
        (range(10, 20), "quit_cold_dist"),
        (range(20, 25), "bit_cold_dist"),
        (range(25, 30), "bit_hot_dist"),
        (range(30, 40), "quit_hot_dist"),
        (range(40, 50), "hot_dist"),
    )

    def __init__(self):
        # Each list is a probability vector over the seven action indices.
        self.cold_dist = [0.05, 0.05, 0.05, 0.05, 0.1, 0.2, 0.5]
        self.quit_cold_dist = [0.05, 0.05, 0.05, 0.1, 0.15, 0.2, 0.4]
        self.bit_cold_dist = [0.05, 0.05, 0.1, 0.1, 0.3, 0.2, 0.2]
        # The "hot" vectors mirror the "cold" ones.
        self.bit_hot_dist = list(reversed(self.bit_cold_dist))
        self.quit_hot_dist = list(reversed(self.quit_cold_dist))
        self.hot_dist = list(reversed(self.cold_dist))

    def get_action(self, tmp):
        """Sample an action for state ``tmp``.

        Returns:
            tuple: ``(action, dist)`` where ``action`` is a NumPy array of
            shape ``(1,)`` holding the sampled action index and ``dist`` is
            the probability vector it was drawn from.

        Raises:
            ValueError: if ``tmp`` does not fall into any configured bin.
        """
        dist = self._get_dist(tmp)
        action = np.random.choice(len(dist), 1, p=dist)
        return action, dist

    def _get_dist(self, tmp):
        """Return the action distribution for state ``tmp``.

        ``tmp`` must compare equal to an integer in 0..49. Anything else
        (out of range, non-integral) raises ``ValueError`` instead of
        silently returning ``None``, which used to surface later as an
        unrelated ``TypeError``.
        """
        for temperatures, dist_name in self._TEMPERATURE_BINS:
            # `in range(...)` matches by equality, the same membership rule as
            # the former `in np.arange(...)`, without allocating an array.
            if tmp in temperatures:
                return getattr(self, dist_name)

        raise ValueError(
            f"No action distribution defined for state {tmp!r}; "
            "expected an integer value in [0, 50)."
        )


In [None]:
# Roll out the hand-written stochastic Agent in the ACControl environment and
# record every transition so it can be used as training data later on.
env = ACControl()
save = Buffer()
agent = Agent()

ID = 1
TRIAL_LEN = 500_000
# Column labels passed to Buffer.get_df; they are listed in the same order as
# the positional arguments given to save.add below.
COLUMNS = ['ID', 'State', 'Action', 'Reward', 'Next_state', 'Prob']

try:
    observation = env.reset()

    for step in range(TRIAL_LEN):
        action, dist = agent.get_action(observation)
        # get_action returns a 1-element array; unwrap it once to a Python int.
        action_index = action.item()

        next_observation, reward = env.step(action_index)
        # The last field is the probability the agent assigned to the action
        # it actually took.
        save.add(ID, observation, action_index, reward, next_observation, dist[action_index])
        observation = next_observation
finally:
    # Release the environment even if the rollout is interrupted or fails.
    env.close()

log = save.get_df(COLUMNS)

In [None]:
import torch

from configs.config import CFG_DICT
from src.model.cvae import CVAE
from src.dataset.cvae_dataset import get_cvae_dataloader
from src.trainer.cvae_trainer import fit
from src.utils.save import save_config, save_src

# Helpers from src.utils.save (presumably they persist this run's config and
# source files -- see that module for the exact behaviour).
save_config()
save_src()

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = CVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=CFG_DICT["TRAIN"]["LR"])

# Build the train/test loaders from the transition log collected earlier.
train_loader, test_loader = get_cvae_dataloader(log)

fit(
    model=model,
    dataloaders=(train_loader, test_loader),
    optimizer=optimizer,
    device=device,
)

In [None]:
# Restore the weights stored under the 'model_state' key of a saved checkpoint.
model_path = 'logs/CVAE/cvae_0922_500_000/best_state.pkl'
# map_location remaps the stored tensors onto the device selected earlier, so a
# checkpoint written on a GPU runtime still loads on a CPU-only one.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state'])

In [None]:
from src.evaluation.metrics import KL_divergence, total_variation_distance

# One row per test sample:
#   column 0 -> argmax index of the model input X (used as the state fed to the
#               agent; X is presumably a one-hot state encoding -- TODO confirm)
#   column 1 -> total variation distance between true and estimated distribution
#   column 2 -> KL divergence between true and estimated distribution
metric_log = np.zeros((len(test_loader.dataset), 3))

log_idx = 0

model.eval()
# Pure inference: disable autograd so no graph is built for each forward pass.
with torch.no_grad():
    for Xs, Ys in test_loader:
        for x_row, y_row in zip(Xs, Ys):
            # Index of the largest entry of X; this is the state, not an action.
            state_idx = torch.argmax(x_row).item()
            metric_log[log_idx, 0] = state_idx
            # Ground-truth action distribution the agent uses for this state.
            real_dist = agent._get_dist(state_idx)

            # The model is fed one sample at a time, as a batch of size 1.
            X = torch.unsqueeze(x_row, 0).to(device, dtype=torch.float)
            Y = torch.unsqueeze(y_row, 0).to(device, dtype=torch.float)
            # Only the fourth output (Y_hat) is needed here.
            _, _, _, Y_hat = model(X, Y)
            estimate_dist = Y_hat.detach().cpu().numpy()[0]

            metric_log[log_idx, 1] = total_variation_distance(real_dist, estimate_dist)
            metric_log[log_idx, 2] = KL_divergence(real_dist, estimate_dist)

            log_idx += 1