In [47]:
from __future__ import annotations
from itertools import count
import datetime

from tqdm import tqdm

import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import gym_envs
import gymnasium as gym
from gymnasium.wrappers import TransformReward, TransformObservation

import utils
from models import *
from agent import Agent
from strategy import DeterministicStrategy

transform = transforms.Compose([ 
    transforms.ToTensor() 
]) 

In [48]:
device = torch.device("cpu")
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


gym_params = utils.GymParams(
    patch_size=(128, 128),
    resize_thumbnail=(512,512),
    max_episode_steps=1000,
)

env = gym.make(
    "gym_envs/WSIWorldEnv-v1",
    wsi_path="test.ndpi",
    render_mode="human",
    patch_size=gym_params.patch_size,
    resize_thumbnail=gym_params.resize_thumbnail,
    max_episode_steps=gym_params.max_episode_steps,
)

transform = transforms.Compose([transforms.ToTensor()])
env = TransformReward(env, lambda r: torch.tensor([r]).to(device))
env = TransformObservation(
    env,
    lambda obs: (
        transform(obs["current_view"]).unsqueeze(0).to(device),
        transform(obs["birdeye_view"]).unsqueeze(0).to(device),
        torch.tensor(obs["level"]).unsqueeze(0).to(device),
        torch.tensor(obs["p_coords"]).unsqueeze(0).to(device),
        torch.tensor(obs["b_rect"]).unsqueeze(0).to(device),
    ),
)

patch_size = env.unwrapped.wsi_wrapper.patch_size
thumbnail_size = env.unwrapped.wsi_wrapper.thumbnail_size
num_actions = env.action_space.n

Started in position:  (31385, 33813)


In [49]:
strategy = DeterministicStrategy()
agent = Agent(strategy, 6, device)
# policy_net = CNN_LSTM(patch_size, thumbnail_size, num_actions).to(device)
# policy_net.load_state_dict(torch.load("../cnn_lstm_weights/DQN_b128_m1000_pS(128, 128)_thS(512, 512)_target_net_Jan_14_24 21:12.pt", map_location=device))
policy_net = CNN_Attention(patch_size, thumbnail_size, num_actions).to(device)
policy_net.load_state_dict(torch.load("../b_att.pt", map_location=device))
policy_net.eval()

CNN_Attention(
  (patch_cnn): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (thumbnail_cnn): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (additional_info_fc): Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0.3, inplace=False)
  )
  (latent_space_no_bn): Sequential(
    (0): Linear(in_features=417824, out_features=1024, bias=True)
    (1): ReLU()
    (2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True

In [50]:
observation, info = env.reset()

Started in position:  (49711, 15694)


In [51]:
with torch.no_grad():
    out = policy_net.forward(torch.randn(4,3,128,128), torch.randn(4,3,512, 512), torch.tensor([[0] * 4]*4), torch.randn(4, 2) * 100, torch.randn(4, 4) * 100)
    print(out)

tensor([[ 6.2689,  7.0238,  7.0049,  4.0556, 10.1276,  0.1714],
        [ 6.3239,  7.0773,  7.0867,  4.1056, 10.2109,  0.1955],
        [ 6.3908,  7.1424,  7.1862,  4.1666, 10.3124,  0.2250],
        [ 6.3042,  7.0581,  7.0574,  4.0877, 10.1810,  0.1869]])


In [52]:
with torch.no_grad():
    patch_features = policy_net.patch_cnn(torch.randn(1,3,128,128))
    thumbnail_features = policy_net.thumbnail_cnn(torch.randn(1,3,512, 512))

    # Flatten the outputs for concatenation
    patch_features = patch_features.view(patch_features.size(0), -1)
    thumbnail_features = thumbnail_features.view(thumbnail_features.size(0), -1)

    # Concatenate additional inputs after processing them through FC layers
    additional_info = torch.cat([torch.tensor([[0,1,1,0]]), torch.randn(1, 2) * 100,  torch.randn(1, 4) * 100], dim=1)
    additional_info = policy_net.additional_info_fc(additional_info)
    print('aditional_info')
    print(additional_info)

    # Concatenate all features
    combined_features = torch.cat(
        [patch_features, thumbnail_features, additional_info], dim=1
    )

    # Compression
    latent_space_features = policy_net.latent_space_no_bn(combined_features)
    attention_weights = F.softmax(policy_net.attention(latent_space_features), dim=1)
    attended_features = latent_space_features * attention_weights
    
    print("latent_space_features")
    print(latent_space_features)
    
    print("attended_features")
    print(attended_features)
    
    print(policy_net.action_fc(attended_features.squeeze(0)))

aditional_info
tensor([[-1.3554, -0.4963, -2.1428, -1.2743,  5.0278, -0.3955,  3.4138, -0.2054,
          1.0209, -1.0970,  4.3079, -1.6001, -0.1234, -0.1014, -1.0252, -2.5816,
         -1.3156, -1.6905, -0.8082, -0.9889, -0.4099, -0.9202, -0.2380, -0.3894,
         -0.2621, -0.7873, -0.4122, -0.1399, -0.2182, -0.2966, -0.8033, -0.7458]])
latent_space_features
tensor([[-1.2964e+00, -4.3772e-01, -6.3094e-01, -6.5595e-01, -1.3615e+00,
         -1.3278e+00, -1.1776e+00, -5.4944e-01,  1.1108e+00, -8.4265e-01,
         -3.5220e-01, -1.9605e+00,  5.2327e-01, -2.4344e-03, -5.2522e-02,
         -2.8179e+00, -1.3112e+00, -1.6990e-01,  6.3202e-01, -1.0903e+00,
         -1.0963e-01,  2.8595e-01, -7.5715e-01, -7.2245e-01, -2.1238e-01,
         -2.3124e+00,  1.1346e-01, -2.3038e-01, -1.2847e+00,  9.6772e-01,
          2.5876e-01,  8.7305e-01, -6.5820e-01, -7.7302e-01,  3.3640e-01,
          1.4067e+00,  2.0990e-01, -1.6126e+00, -1.8542e+00, -1.3336e+00,
         -2.4878e+00, -1.1591e+00, -1.9563e+0

In [60]:
env = gym.make(
    "gym_envs/WSIWorldEnv-v1",
    wsi_path="test.ndpi",
    render_mode="human",
    patch_size=gym_params.patch_size,
    resize_thumbnail=gym_params.resize_thumbnail,
    max_episode_steps=gym_params.max_episode_steps,
)

transform = transforms.Compose([transforms.ToTensor()])
env = TransformReward(env, lambda r: torch.tensor([r]).to(device))
env = TransformObservation(
    env,
    lambda obs: (
        transform(obs["current_view"]).unsqueeze(0).to(device),
        transform(obs["birdeye_view"]).unsqueeze(0).to(device),
        torch.tensor(obs["level"]).unsqueeze(0).to(device),
        torch.tensor(obs["p_coords"]).unsqueeze(0).to(device),
        torch.tensor(obs["b_rect"]).unsqueeze(0).to(device),
    ),
)

patch_size = env.unwrapped.wsi_wrapper.patch_size
thumbnail_size = env.unwrapped.wsi_wrapper.thumbnail_size
num_actions = env.action_space.n
observation, info = env.reset()

for _ in range(1000):
    # action = take_action(observation).item()
    action = agent.select_action(observation, policy_net).item()
    old_obs = observation
    observation, reward, terminated, truncated, info = env.step(action)
    if torch.all(observation[0] == old_obs[0]).item() or torch.all(observation[1] == old_obs[1]).item():
        print('mochkil')

    print(action, " reward -> ", reward)
    

    if terminated or truncated:
        break

env.close()

Started in position:  (28840, 28017)
Started in position:  (34636, 31043)
1  reward ->  tensor([-0.9531])
1  reward ->  tensor([-0.9979])
1  reward ->  tensor([-0.9850])
1  reward ->  tensor([-0.9879])
1  reward ->  tensor([-0.9705])
1  reward ->  tensor([-0.9100])
1  reward ->  tensor([-0.9264])
1  reward ->  tensor([-0.9619])
1  reward ->  tensor([-0.9630])
1  reward ->  tensor([-0.9703])
1  reward ->  tensor([-0.9790])
1  reward ->  tensor([-0.9822])
1  reward ->  tensor([-0.9346])
1  reward ->  tensor([-0.9000])
1  reward ->  tensor([-0.9240])
1  reward ->  tensor([-0.9081])
1  reward ->  tensor([-0.9786])
1  reward ->  tensor([-0.9783])
1  reward ->  tensor([-0.9742])
1  reward ->  tensor([-0.9619])
1  reward ->  tensor([-0.9197])
1  reward ->  tensor([-0.9112])
1  reward ->  tensor([-0.9423])
1  reward ->  tensor([-0.9293])
1  reward ->  tensor([-0.8532])
1  reward ->  tensor([-0.9858])
1  reward ->  tensor([-0.9821])
1  reward ->  tensor([-0.9790])
1  reward ->  tensor([-0.9858]

In [11]:
def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    
    Usage: Plug this function in Trainer class after loss.backwards() as 
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    ave_grads = []
    max_grads= []
    layers = []
    for n, p in named_parameters:
        if(p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])

    plt.show()

In [43]:
observation, info = env.reset()
loss_per_episode = 0
reward_per_episode = 0

for timestep in (pbar := tqdm(count(), total=gym_params.max_episode_steps)):
    pbar.set_description(
        f"Currently in Ep: {episode} || GPU MEM USED: {round(torch.cuda.memory_allocated() / 1e9, 2)}"
    )
    # current_view, birdeye_view, level, p_coords, b_rect

    action = agent.select_action(observation, policy_net)
    next_observation, reward, terminated, truncated, info = env.step(
        action.item()
    )
    reward_per_episode += reward

    patch, bird_view, level, p_coord, b_rect = observation
    (
        next_patch,
        next_bird_view,
        next_level,
        next_p_coord,
        next_b_rect,
    ) = next_observation

    memory.push(
        utils.RichExperience(
            patch,
            bird_view,
            action,
            level,
            p_coord,
            b_rect,
            next_patch,
            next_bird_view,
            next_level,
            next_p_coord,
            next_b_rect,
            reward,
        )
    )
    observation = next_observation

    if memory.can_provide_sample(training_params.batch_size):
        experiences = memory.sample(training_params.batch_size)
        (
            patches,
            bird_views,
            levels,
            p_coords,
            b_rects,
            actions,
            rewards,
            next_patches,
            next_bird_views,
            next_levels,
            next_p_coords,
            next_b_rects,
        ) = utils.extract_rich_experiences_tensors(experiences)
        current_q_values = QValues.get_current(
            policy_net,
            actions,
            (
                patches,
                bird_views,
                levels,
                p_coords,
                b_rects,
            ),
        )
        next_q_values = QValues.get_next(
            target_net,
            (
                next_patches,
                next_bird_views,
                next_levels,
                next_p_coords,
                next_b_rects,
            ),
        )
        target_q_values = (next_q_values * training_params.gamma) + rewards

        loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
        writer.add_scalar(f"Loss/ep{episode}_Timestep", loss, timestep)

        optimizer.zero_grad()
        loss.backward()
        plo

Started in position:  (45151, 27337)
Started in position:  (44366, 25492)


Currently in Ep: 0 || GPU MEM USED: 0.0:   0%|         | 0/1000 [00:00<?, ?it/s]

tensor([[-7.5330, -8.4948, -7.4875, -7.4680, -7.6992, -7.5620]])





NameError: name 'actions' is not defined

In [28]:
plot_grad_flow(policy_net.patch_cnn.named_modules())

AttributeError: 'Sequential' object has no attribute 'requires_grad'

In [6]:
env = gym.make(
    "gym_envs/WSIWorldEnv-v1",
    wsi_path="test.ndpi",
    render_mode="human",
    patch_size=gym_params.patch_size,
    resize_thumbnail=gym_params.resize_thumbnail,
    max_episode_steps=gym_params.max_episode_steps,
)

transform = transforms.Compose([transforms.ToTensor()])
env = TransformReward(env, lambda r: torch.tensor([r]).to(device))
env = TransformObservation(
    env,
    lambda obs: (
        transform(obs["current_view"]).unsqueeze(0).to(device),
        transform(obs["birdeye_view"]).unsqueeze(0).to(device),
        torch.tensor(obs["level"]).unsqueeze(0).to(device),
        torch.tensor(obs["p_coords"]).unsqueeze(0).to(device),
        torch.tensor(obs["b_rect"]).unsqueeze(0).to(device),
    ),
)

patch_size = env.unwrapped.wsi_wrapper.patch_size
thumbnail_size = env.unwrapped.wsi_wrapper.thumbnail_size
num_actions = env.action_space.n
observation, info = env.reset()

for _ in range(1000):
    # action = take_action(observation).item()
    action = agent.select_action(observation, policy_net).item()
    print(action)
    old_obs = observation
    print(action)
    observation, reward, terminated, truncated, info = env.step(action)
    if torch.all(observation[0] == old_obs[0]).item() or torch.all(observation[1] == old_obs[1]).item():
        print('mochkil')
    

    if terminated or truncated:
        break

env.close()

Started in position:  (39128, 32939)
Started in position:  (44969, 30495)
2


  logger.warn(f"{pre} is not within the observation space.")


(tensor([[[[0.3922, 0.3882, 0.4118,  ..., 0.7176, 0.6824, 0.6196],
          [0.5255, 0.5569, 0.4588,  ..., 0.7529, 0.6667, 0.5059],
          [0.6275, 0.5882, 0.5333,  ..., 0.7412, 0.6196, 0.4275],
          ...,
          [0.8902, 0.8706, 0.7843,  ..., 0.2980, 0.3216, 0.3020],
          [0.8824, 0.8706, 0.7961,  ..., 0.3216, 0.3412, 0.3686],
          [0.8745, 0.8667, 0.8118,  ..., 0.3725, 0.4118, 0.4471]],

         [[0.2235, 0.2118, 0.2275,  ..., 0.5922, 0.5490, 0.4863],
          [0.3137, 0.3529, 0.2510,  ..., 0.6275, 0.5412, 0.3882],
          [0.4157, 0.3765, 0.3176,  ..., 0.6157, 0.4902, 0.3059],
          ...,
          [0.8784, 0.8627, 0.7922,  ..., 0.1490, 0.1725, 0.1412],
          [0.8706, 0.8627, 0.8039,  ..., 0.1490, 0.1608, 0.1725],
          [0.8627, 0.8588, 0.8196,  ..., 0.1843, 0.2039, 0.2196]],

         [[0.5020, 0.4824, 0.4941,  ..., 0.7765, 0.7490, 0.7216],
          [0.5569, 0.6118, 0.5255,  ..., 0.8196, 0.7373, 0.5922],
          [0.6431, 0.6196, 0.5804,  ..., 

KeyboardInterrupt: 

In [7]:
env.close()