In [1]:
from __future__ import annotations
from itertools import count
import datetime

from tqdm import tqdm

import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import gym_envs
import gymnasium as gym
from gymnasium.wrappers import TransformReward, TransformObservation

import utils
from models import *
from agent import Agent
from strategy import DeterministicStrategy

transform = transforms.Compose([ 
    transforms.ToTensor() 
]) 

In [2]:
device = torch.device("cpu")
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


gym_params = utils.GymParams(
    patch_size=(128, 128),
    resize_thumbnail=(512,512),
    max_episode_steps=1000,
)

env = gym.make(
    "gym_envs/WSIWorldEnv-v1",
    wsi_path="test.ndpi",
    render_mode="human",
    patch_size=gym_params.patch_size,
    resize_thumbnail=gym_params.resize_thumbnail,
    max_episode_steps=gym_params.max_episode_steps,
)

transform = transforms.Compose([transforms.ToTensor()])
env = TransformReward(env, lambda r: torch.tensor([r]).to(device))
env = TransformObservation(
    env,
    lambda obs: (
        transform(obs["current_view"]).unsqueeze(0).to(device),
        transform(obs["birdeye_view"]).unsqueeze(0).to(device),
        torch.tensor(obs["level"]).unsqueeze(0).to(device),
        torch.tensor(obs["p_coords"]).unsqueeze(0).to(device),
        torch.tensor(obs["b_rect"]).unsqueeze(0).to(device),
    ),
)

patch_size = env.unwrapped.wsi_wrapper.patch_size
thumbnail_size = env.unwrapped.wsi_wrapper.thumbnail_size
num_actions = env.action_space.n

Started in position:  (42871, 21953)


In [3]:
strategy = DeterministicStrategy()
agent = Agent(strategy, 6, device)
policy_net = CNN_LSTM(patch_size, thumbnail_size, num_actions).to(device)
policy_net.load_state_dict(torch.load("../DQN_b64_m2500_pS(128, 128)_thS(512, 512)_target_net_Jan_14_24 13:25.pt"))
policy_net.eval()

CNN_LSTM(
  (patch_cnn): Sequential(
    (0): Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (thumbnail_cnn): Sequential(
    (0): Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (additional_info_fc): Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
  )
  (latent_space): Sequential(
    (0): Linear(in_features=626720, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=128, bias=True)
    (3): ReLU()
  )
  (attention): Sequential(
    (0): Linear(in_featur

In [4]:
observation, info = env.reset()

Started in position:  (34274, 32805)


  logger.warn(f"{pre} is not within the observation space.")
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


In [50]:
with torch.no_grad():
    out = policy_net.forward(torch.randn(1,3,128,128), torch.randn(1,3,512, 512), torch.tensor([[0,1,1,0]]), torch.randn(1, 2) * 100, torch.randn(1, 4) * 100)
    print(out)

tensor([[-5.9220, -6.1816, -5.7517, -5.7562, -5.9765, -5.8443]])


In [139]:
old_lstm_out = lstm_out

In [189]:
torch.where((latent_space_features == 0) == True)[0].shape

torch.Size([83])

In [193]:
with torch.no_grad():
    patch_features = policy_net.patch_cnn(torch.randn(1,3,128,128))
    thumbnail_features = policy_net.thumbnail_cnn(torch.randn(1,3,512, 512))

    # Flatten the outputs for concatenation
    patch_features = patch_features.view(patch_features.size(0), -1)
    thumbnail_features = thumbnail_features.view(thumbnail_features.size(0), -1)

    # Concatenate additional inputs after processing them through FC layers
    additional_info = torch.cat([torch.tensor([[0,1,1,0]]), torch.randn(1, 2) * 100,  torch.randn(1, 4) * 100], dim=1)
    additional_info = policy_net.additional_info_fc(additional_info)
    print('aditional_info')
    print(additional_info)

    # Concatenate all features
    combined_features = torch.cat(
        [patch_features, thumbnail_features, additional_info], dim=1
    )

    # Compression
    latent_space_features = policy_net.latent_space(combined_features)
    attention_weights = F.softmax(policy_net.attention(latent_space_features), dim=1)
    attended_features = latent_space_features * attention_weights
    
    print("latent_space_features")
    print(latent_space_features)
    
    print("attended_features")
    print(attended_features)
    
    lstm_out, _ = policy_net.lstm(
            attended_features.unsqueeze(0)
        )  # Add batch dimension if necessary

    print("lstm_out")
    print(lstm_out)
    
    print(policy_net.action_effect_fc(lstm_out.squeeze(0)))

aditional_info
tensor([[56.9997, 56.7730,  0.0000,  0.0000,  0.0000,  0.0000, 18.2573, 16.5518,
          0.0000,  0.0000,  0.0000, 59.2336,  1.8961, 23.5242,  0.0000,  0.0000,
         17.2483,  5.9375, 20.2759,  0.0000, 19.3641,  0.0000, 76.1738, 37.4877,
         25.0711, 73.8416, 57.5073,  0.0000, 35.7419, 67.5922, 92.1250, 18.8970]])
latent_space_features
tensor([[ 10806.9609, 180107.6875,      0.0000,      0.0000, 215660.2969,
              0.0000,      0.0000,      0.0000,      0.0000,      0.0000,
          78353.1797,      0.0000,      0.0000,      0.0000, 224700.3125,
              0.0000,      0.0000,      0.0000,      0.0000, 114811.2109,
          67736.3750,      0.0000,      0.0000,      0.0000,  35517.1250,
              0.0000,   3129.1111, 107823.2969,      0.0000,      0.0000,
              0.0000,  70110.2188,      0.0000,  36015.8086,  65098.7344,
              0.0000,      0.0000,      0.0000,      0.0000,      0.0000,
          41967.0234,      0.0000,  95631.796

In [8]:
env = gym.make(
    "gym_envs/WSIWorldEnv-v1",
    wsi_path="test.ndpi",
    render_mode="human",
    patch_size=gym_params.patch_size,
    resize_thumbnail=gym_params.resize_thumbnail,
    max_episode_steps=gym_params.max_episode_steps,
)

transform = transforms.Compose([transforms.ToTensor()])
env = TransformReward(env, lambda r: torch.tensor([r]).to(device))
env = TransformObservation(
    env,
    lambda obs: (
        transform(obs["current_view"]).unsqueeze(0).to(device),
        transform(obs["birdeye_view"]).unsqueeze(0).to(device),
        torch.tensor(obs["level"]).unsqueeze(0).to(device),
        torch.tensor(obs["p_coords"]).unsqueeze(0).to(device),
        torch.tensor(obs["b_rect"]).unsqueeze(0).to(device),
    ),
)

patch_size = env.unwrapped.wsi_wrapper.patch_size
thumbnail_size = env.unwrapped.wsi_wrapper.thumbnail_size
num_actions = env.action_space.n
observation, info = env.reset()

for _ in range(1000):
    # action = take_action(observation).item()
    action = agent.select_action(observation, policy_net).item()
    print(action)
    old_obs = observation
    print(action)
    observation, reward, terminated, truncated, info = env.step(action)
    if torch.all(observation[0] == old_obs[0]).item() or torch.all(observation[1] == old_obs[1]).item():
        print('mochkil')
    

    if terminated or truncated:
        break

env.close()

Started in position:  (29743, 27640)
Started in position:  (38501, 26645)
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2


KeyboardInterrupt: 

In [11]:
def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    
    Usage: Plug this function in Trainer class after loss.backwards() as 
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    ave_grads = []
    max_grads= []
    layers = []
    for n, p in named_parameters:
        if(p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])

    plt.show()

In [43]:
observation, info = env.reset()
loss_per_episode = 0
reward_per_episode = 0

for timestep in (pbar := tqdm(count(), total=gym_params.max_episode_steps)):
    pbar.set_description(
        f"Currently in Ep: {episode} || GPU MEM USED: {round(torch.cuda.memory_allocated() / 1e9, 2)}"
    )
    # current_view, birdeye_view, level, p_coords, b_rect

    action = agent.select_action(observation, policy_net)
    next_observation, reward, terminated, truncated, info = env.step(
        action.item()
    )
    reward_per_episode += reward

    patch, bird_view, level, p_coord, b_rect = observation
    (
        next_patch,
        next_bird_view,
        next_level,
        next_p_coord,
        next_b_rect,
    ) = next_observation

    memory.push(
        utils.RichExperience(
            patch,
            bird_view,
            action,
            level,
            p_coord,
            b_rect,
            next_patch,
            next_bird_view,
            next_level,
            next_p_coord,
            next_b_rect,
            reward,
        )
    )
    observation = next_observation

    if memory.can_provide_sample(training_params.batch_size):
        experiences = memory.sample(training_params.batch_size)
        (
            patches,
            bird_views,
            levels,
            p_coords,
            b_rects,
            actions,
            rewards,
            next_patches,
            next_bird_views,
            next_levels,
            next_p_coords,
            next_b_rects,
        ) = utils.extract_rich_experiences_tensors(experiences)
        current_q_values = QValues.get_current(
            policy_net,
            actions,
            (
                patches,
                bird_views,
                levels,
                p_coords,
                b_rects,
            ),
        )
        next_q_values = QValues.get_next(
            target_net,
            (
                next_patches,
                next_bird_views,
                next_levels,
                next_p_coords,
                next_b_rects,
            ),
        )
        target_q_values = (next_q_values * training_params.gamma) + rewards

        loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
        writer.add_scalar(f"Loss/ep{episode}_Timestep", loss, timestep)

        optimizer.zero_grad()
        loss.backward()
        plo

Started in position:  (45151, 27337)
Started in position:  (44366, 25492)


Currently in Ep: 0 || GPU MEM USED: 0.0:   0%|         | 0/1000 [00:00<?, ?it/s]

tensor([[-7.5330, -8.4948, -7.4875, -7.4680, -7.6992, -7.5620]])





NameError: name 'actions' is not defined

In [28]:
plot_grad_flow(policy_net.patch_cnn.named_modules())

AttributeError: 'Sequential' object has no attribute 'requires_grad'

In [6]:
env = gym.make(
    "gym_envs/WSIWorldEnv-v1",
    wsi_path="test.ndpi",
    render_mode="human",
    patch_size=gym_params.patch_size,
    resize_thumbnail=gym_params.resize_thumbnail,
    max_episode_steps=gym_params.max_episode_steps,
)

transform = transforms.Compose([transforms.ToTensor()])
env = TransformReward(env, lambda r: torch.tensor([r]).to(device))
env = TransformObservation(
    env,
    lambda obs: (
        transform(obs["current_view"]).unsqueeze(0).to(device),
        transform(obs["birdeye_view"]).unsqueeze(0).to(device),
        torch.tensor(obs["level"]).unsqueeze(0).to(device),
        torch.tensor(obs["p_coords"]).unsqueeze(0).to(device),
        torch.tensor(obs["b_rect"]).unsqueeze(0).to(device),
    ),
)

patch_size = env.unwrapped.wsi_wrapper.patch_size
thumbnail_size = env.unwrapped.wsi_wrapper.thumbnail_size
num_actions = env.action_space.n
observation, info = env.reset()

for _ in range(1000):
    # action = take_action(observation).item()
    action = agent.select_action(observation, policy_net).item()
    print(action)
    old_obs = observation
    print(action)
    observation, reward, terminated, truncated, info = env.step(action)
    if torch.all(observation[0] == old_obs[0]).item() or torch.all(observation[1] == old_obs[1]).item():
        print('mochkil')
    

    if terminated or truncated:
        break

env.close()

Started in position:  (39128, 32939)
Started in position:  (44969, 30495)
2


  logger.warn(f"{pre} is not within the observation space.")


(tensor([[[[0.3922, 0.3882, 0.4118,  ..., 0.7176, 0.6824, 0.6196],
          [0.5255, 0.5569, 0.4588,  ..., 0.7529, 0.6667, 0.5059],
          [0.6275, 0.5882, 0.5333,  ..., 0.7412, 0.6196, 0.4275],
          ...,
          [0.8902, 0.8706, 0.7843,  ..., 0.2980, 0.3216, 0.3020],
          [0.8824, 0.8706, 0.7961,  ..., 0.3216, 0.3412, 0.3686],
          [0.8745, 0.8667, 0.8118,  ..., 0.3725, 0.4118, 0.4471]],

         [[0.2235, 0.2118, 0.2275,  ..., 0.5922, 0.5490, 0.4863],
          [0.3137, 0.3529, 0.2510,  ..., 0.6275, 0.5412, 0.3882],
          [0.4157, 0.3765, 0.3176,  ..., 0.6157, 0.4902, 0.3059],
          ...,
          [0.8784, 0.8627, 0.7922,  ..., 0.1490, 0.1725, 0.1412],
          [0.8706, 0.8627, 0.8039,  ..., 0.1490, 0.1608, 0.1725],
          [0.8627, 0.8588, 0.8196,  ..., 0.1843, 0.2039, 0.2196]],

         [[0.5020, 0.4824, 0.4941,  ..., 0.7765, 0.7490, 0.7216],
          [0.5569, 0.6118, 0.5255,  ..., 0.8196, 0.7373, 0.5922],
          [0.6431, 0.6196, 0.5804,  ..., 

KeyboardInterrupt: 

In [7]:
env.close()