In [1]:
import multiprocessing
import random
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.environments import WirelessCommunicationsEnv
from src.utils import Discretizer
from src.sampler import PendulumTrajectorySampler, EpsilonGreedyPendulumTrajectorySampler
from src.trainer import QNetworkTrainer, QNetworkTester
from src.models import PARAFAC

In [2]:
ts = [5, 10, 15, 20]

envs = [WirelessCommunicationsEnv(
    T=1_000,
    K=3,
    snr_max=10,
    snr_min=2,
    snr_autocorr=0.7,
    P_occ=np.array(
        [  
            [0.4, 0.6],
            [0.6, 0.4],
        ]
    ),
    occ_initial=[1, 1, 1],
    batt_harvest=1.0, 
    P_harvest=0.2, 
    batt_initial=5,
    batt_max_capacity=10,  
    batt_weight=1.0, 
    queue_initial=10,
    queue_arrival=5,
    queue_max_capacity=20,
    t_queue_arrival=ts[i],
    queue_weight=0.2,
    loss_busy=0.8,  
) for i in range(len(ts))]

In [3]:
discretizer = Discretizer(
    min_points_states=[0, 0, 0, 0, 0, 0, 0, 0],
    max_points_states=[20, 20, 10, 1, 1, 1, 20, 10],
    bucket_states=[10, 10, 10, 2, 2, 2, 10, 10],
    min_points_actions=[0, 0, 0],
    max_points_actions=[2, 2, 2],
    bucket_actions=[10, 10, 10],
)

In [4]:
nS = [10, 10, 10, 2, 2, 2, 10, 10]
nA = [10, 10, 10]
nT = 4
gamma = 0.99

In [39]:
def create_target(states_next, rewards, Q, tasks=None):
    if tasks is not None:
        idx_target = torch.cat((tasks.unsqueeze(1), states_next), dim=1)
    else:
        idx_target = states_next

    with torch.no_grad():
        q_target = rewards + gamma * Q(idx_target).max(dim=1).values

    return q_target

def create_idx_hat(states, actions, tasks=None):
    if tasks is not None:
        idx_hat = torch.cat((tasks.unsqueeze(1), states, actions), dim=1)
    else:
        idx_hat = torch.cat((states, actions), dim=1)
    return idx_hat

def update_model(s_idx, sp_idx, a_idx, r, Q, opt, tasks=None):
    for factor in Q.factors:
        q_target = create_target(sp_idx, r, Q, tasks)
        idx_hat = create_idx_hat(s_idx, a_idx, tasks)
        q_hat = Q(idx_hat)

        opt.zero_grad()
        loss = torch.nn.MSELoss()(q_hat, q_target)
        loss.backward()

        with torch.no_grad():
            for frozen_factor in Q.factors:
                if frozen_factor is not factor:
                    frozen_factor.grad = None

        opt.step()

def select_random_action() -> np.ndarray:
        a_idx = tuple(np.random.randint(discretizer.bucket_actions).tolist())
        return discretizer.get_action_from_index(a_idx), a_idx

def select_greedy_action(Q, s: np.ndarray) -> np.ndarray:
    with torch.no_grad():
        s_idx = np.concatenate([discretizer.get_state_index(s)])
        a_idx_flat = Q(s_idx).argmax().detach().item()
        a_idx = np.unravel_index(a_idx_flat, discretizer.bucket_actions)
        return discretizer.get_action_from_index(a_idx), a_idx

def select_action(Q, s: np.ndarray, epsilon: float) -> np.ndarray:
    if np.random.rand() < epsilon:
        return select_random_action()
    return select_greedy_action(Q, s)

# Mono-task

In [40]:
E = 1000
H = 1000
lr = 0.01
eps = 1.0
eps_decay = 0.99999
eps_min = 0.1

k = 20
n_upd = nT

env_id = 1

def run_test_episode(Q, env_idx):
    with torch.no_grad():
        G = 0
        s, _ = envs[env_idx].reset()
        s_idx = torch.tensor(discretizer.get_state_index(s)).unsqueeze(0)
        for h in range(H):
            a_idx = Q(s_idx).argmax().item()
            a = discretizer.get_action_from_index(a_idx)
            a_idx = torch.tensor(a_idx).unsqueeze(0)
            sp, r, d, _, _ = envs[env_idx].step(a)
            sp_idx = torch.tensor(discretizer.get_state_index(sp)).unsqueeze(0)

            G += r

            if d:
                break

            s = sp
            s_idx = sp_idx
    return G

In [41]:
Gs = []
Q = PARAFAC(dims=nS + nA, k=k, scale=0.1)
opt = torch.optim.Adamax(Q.parameters(), lr=lr)
ds = 0
for episode in range(E):
    s, _ = envs[env_id].reset()
    s_idx = torch.tensor(discretizer.get_state_index(s)).unsqueeze(0)
    for h in range(H):
        a, a_idx = select_action(Q, s_idx, eps)
        a_idx = torch.tensor(a_idx).unsqueeze(0)
        sp, r, d, _, _ = envs[env_id].step(a)
        sp_idx = torch.tensor(discretizer.get_state_index(sp)).unsqueeze(0)

        for _ in range(n_upd):
            update_model(s_idx, sp_idx, a_idx, r, Q, opt)

        s = sp
        s_idx = sp_idx
        eps = max(eps*eps_decay, eps_min)

        # if h % 10 == 0:
    G = run_test_episode(Q, env_id)
    Gs.append(G)
    print(f"\rEpoch: {episode} - Return: {G} - {eps}", end="")
    

Epoch: 5 - Return: -128.49145094824613 - 0.9940179342332267

AttributeError: 'Tensor' object has no attribute 'astype'

In [37]:
s, _ = envs[0].reset()
s_idx = torch.tensor(discretizer.get_state_index(s)).unsqueeze(0)
a, a_idx = select_action(Q, s_idx, eps)
a_idx = torch.tensor(a_idx).unsqueeze(0)
idx = s_idx
idx = torch.cat((s_idx, a_idx), dim=1)
q = Q(idx)

In [38]:
q

tensor([-8.2449e-12], dtype=torch.float64, grad_fn=<SumBackward1>)