In [1]:
ckpt_name = 'pretrain.spirl_v5_skill3_z10_1kiq585w_simple_c_v2_B0200_20_C_5000.pt'
skill_length = 3
frame_stack = 5
resize_res = 28
is_quantized = True
num_level = 0
# thresholds = list()
thresholds = [10 * i for i in range(0, 20)]
pos = False

In [2]:
from collections import deque
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as torch_dist
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp

from tqdm.notebook import trange
import wandb

from vit_pytorch import ViT
# modified output ViT; from (batch, num_class) to (batch, dim)

# from simpl.collector import ConcurrentCollector, TimeLimitCollector, GPUWorker, Buffer
# from simpl.nn import itemize
# from simpl.math import discount
# from simpl.rl.policy import ContextTruncatedNormalMLPPolicy
# from simpl.rl.qf import MLPQF

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import sys
sys.path.append('C:\\Users\\Lee Geonju\\Desktop\\drone\\SiMPL')

In [4]:
import torch.nn as nn
from simpl.nn import MLP

from simpl.nn import ToDeviceMixin
import torch.distributions as torch_dist

from simpl.math import inverse_softplus, inverse_sigmoid

class SkillEncoder(ToDeviceMixin, nn.Module):
    def __init__(self, action_dim, z_dim, hidden_dim, n_lstm, n_mlp_hidden):
        super().__init__()
        
        self.action_dim = action_dim
        
        self.lstm = nn.LSTM(
            action_dim,
            hidden_dim, n_lstm, batch_first=True
        )
        self.mlp = MLP([hidden_dim]*n_mlp_hidden + [2*z_dim], 'relu')
        
        self.register_buffer('prior_loc', torch.zeros(z_dim))
        self.register_buffer('prior_scale', torch.ones(z_dim))
        # self.register_buffer('h0', torch.zeros(n_lstm, hidden_dim))
        # self.register_buffer('c0', torch.zeros(n_lstm, hidden_dim))

    @property
    def prior_dist(self):
        return torch_dist.Independent(torch_dist.Normal(self.prior_loc, self.prior_scale), 1)
    
        
    def dist(self, batch_seq_action):
        # batch_h0 = self.h0[:, None, :].expand(-1, len(batch_seq_state), -1)
        # batch_c0 = self.c0[:, None, :].expand(-1, len(batch_seq_state), -1)
        
        batch_seq_onehot_action = F.one_hot(batch_seq_action, num_classes=self.action_dim).float()
        batch_seq_out, _ = self.lstm(batch_seq_onehot_action)
        batch_last_out = batch_seq_out[:, -1, :]
        batch_loc, batch_pre_scale = self.mlp(batch_last_out).chunk(2, dim=-1)
        batch_scale = F.softplus(batch_pre_scale)
        
        return torch_dist.Independent(
            torch_dist.Normal(batch_loc, batch_scale)
        , 1)

    
class PriorPolicy(ToDeviceMixin, nn.Module):
    def __init__(self, state_shape, z_dim, hidden_dim, n_hidden,
                 min_scale=0.001, max_scale=None, init_scale=0.1):
        super().__init__()
        
#         assert state_shape == (5, 84, 84)
#         assert hidden_dim == 128
        
        self.z_dim = z_dim
        if resize_res == 84:
            self.conv_net = nn.Sequential(
                nn.Conv2d(frame_stack, 32, kernel_size=4, stride=3),
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=4, stride=3),
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=4, stride=2),
                nn.ReLU(),
                nn.Conv2d(128, 128, kernel_size=3, stride=1),
            )
            
        elif resize_res == 28:
            self.conv_net = nn.Sequential(
                nn.Conv2d(frame_stack, 64, kernel_size=4, stride=3),
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=4, stride=2),
                nn.ReLU(),
                nn.Conv2d(128, 128, kernel_size=3, stride=1),
            )
            
        self.mlp = MLP([128]*n_hidden + [2*z_dim], 'relu')
        
        self.min_scale = min_scale
        self.max_scale = max_scale
        
        if max_scale is None:
            self.pre_init_scale = inverse_softplus(init_scale)
        else:
            self.pre_init_scale = inverse_sigmoid(init_scale / max_scale)
        
        
#     def dist(self, batch_state, batch_pos):
#         input_dim = batch_state.dim()
#         if input_dim > 4:
#             batch_shape = batch_state.shape[:-3] 
#             data_shape = batch_state.shape[-3:]
#             batch_state = batch_state.view(-1, *data_shape)
#         batch_h = self.conv_net(batch_state)[..., 0, 0]
#         if pos:
#             batch_h = torch.cat((batch_h, batch_pos), dim=-1)

#         batch_loc, batch_pre_scale = self.mlp(batch_h).chunk(2, dim=-1)

#         if self.max_scale is None:
#             batch_scale = self.min_scale + F.softplus(self.pre_init_scale + batch_pre_scale)
#         else:
#             batch_scale = self.min_scale + self.max_scale*torch.sigmoid(self.pre_init_scale + batch_pre_scale)
        
#         if input_dim > 4:
#             batch_loc = batch_loc.view(*batch_shape, self.z_dim)
#             batch_scale = batch_scale.view(*batch_shape, self.z_dim)
        
#         return torch_dist.Independent(
#             torch_dist.Normal(batch_loc, batch_scale)
#         , 1)
    
    
    def dist(self, batch_state):
        input_dim = batch_state.dim()
        if input_dim > 4:
            batch_shape = batch_state.shape[:-3] 
            data_shape = batch_state.shape[-3:]
            batch_state = batch_state.view(-1, *data_shape)
        batch_h = self.conv_net(batch_state)[..., 0, 0]
        batch_loc, batch_pre_scale = self.mlp(batch_h).chunk(2, dim=-1)

        if self.max_scale is None:
            batch_scale = self.min_scale + F.softplus(self.pre_init_scale + batch_pre_scale)
        else:
            batch_scale = self.min_scale + self.max_scale*torch.sigmoid(self.pre_init_scale + batch_pre_scale)
        
        if input_dim > 4:
            batch_loc = batch_loc.view(*batch_shape, self.z_dim)
            batch_scale = batch_scale.view(*batch_shape, self.z_dim)
        
        return torch_dist.Independent(
            torch_dist.Normal(batch_loc, batch_scale)
        , 1)


# class PriorPolicy(ToDeviceMixin, nn.Module):
#     def __init__(self, state_shape, z_dim, hidden_dim, n_hidden,
#                  min_scale=0.001, max_scale=None, init_scale=0.1):
#         super().__init__()
        
#         assert state_shape == (5, 84, 84)
#         assert hidden_dim == 128
        
#         self.z_dim = z_dim
#         self.vit = ViT(
#             image_size = 84,
#             patch_size = 21,
#             num_classes = 3,
#             dim = hidden_dim,
#             depth = 3,
#             heads = 8,
#             mlp_dim = 256,
#             dropout = 0.1,
#             emb_dropout = 0.1,
#             channels = 5,
#         )
#         self.mlp = MLP([128]*n_hidden + [2*z_dim], 'relu')
        
#         self.min_scale = min_scale
#         self.max_scale = max_scale
        
#         if max_scale is None:
#             self.pre_init_scale = inverse_softplus(init_scale)
#         else:
#             self.pre_init_scale = inverse_sigmoid(init_scale / max_scale)
        
#     def dist(self, batch_state):
#         input_dim = batch_state.dim()
#         batch_shape = batch_state.shape[:-3] 
#         data_shape = batch_state.shape[-3:]
#         batch_state = batch_state.view(-1, *data_shape)
#         batch_h = self.vit(batch_state)
#         batch_loc, batch_pre_scale = self.mlp(batch_h).chunk(2, dim=-1)

#         if self.max_scale is None:
#             batch_scale = self.min_scale + F.softplus(self.pre_init_scale + batch_pre_scale)
#         else:
#             batch_scale = self.min_scale + self.max_scale*torch.sigmoid(self.pre_init_scale + batch_pre_scale)
        
#         if input_dim > 4:
#             batch_loc = batch_loc.view(*batch_shape, self.z_dim)
#             batch_scale = batch_scale.view(*batch_shape, self.z_dim)
        
#         return torch_dist.Independent(
#             torch_dist.Normal(batch_loc, batch_scale)
#         , 1)
    
# class GELU(torch.nn.Module):
#     def forward(self, input: torch.Tensor) -> torch.Tensor:
#         return torch.nn.functional.gelu(input)
# torch.nn.modules.activation.GELU = GELU
        
    
# class LowPolicy(ToDeviceMixin, nn.Module):
#     def __init__(self, state_shape, action_dim, z_dim, hidden_dim, n_hidden):
#         super().__init__()
        
#         assert state_shape == (5, 84, 84)
        
#         self.action_dim = action_dim
#         self.conv_net = nn.Sequential(
#             nn.Conv2d(5, hidden_dim, kernel_size=4, stride=3),
#             nn.ReLU(),
#             nn.Conv2d(hidden_dim, hidden_dim, kernel_size=4, stride=3),
#             nn.ReLU(),
#             nn.Conv2d(hidden_dim, hidden_dim, kernel_size=4, stride=2),
#             nn.ReLU(),
#             nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1),
#         )
#         self.mlp = MLP([hidden_dim + z_dim] + [hidden_dim]*(n_hidden-1) + [action_dim], 'relu')
        
#     def dist(self, batch_state, batch_z):
#         input_dim = batch_state.dim()
#         if input_dim > 4:
#             batch_shape = batch_state.shape[:-3]
            
#             data_shape = batch_state.shape[len(batch_shape):]
#             batch_state = batch_state.view(-1, *data_shape)

#             data_shape = batch_z.shape[len(batch_shape):]
#             batch_z = batch_z.reshape(-1, *data_shape)
        
#         batch_h = self.conv_net(batch_state)[..., 0, 0]
#         batch_h_z = torch.cat([batch_h, batch_z], dim=-1)
#         batch_logits = self.mlp(batch_h_z)
        
#         if input_dim > 4:
#             batch_logits = batch_logits.view(*batch_shape, self.action_dim)
        
#         return torch_dist.Categorical(logits=batch_logits)


class LowPolicy(ToDeviceMixin, nn.Module):
    def __init__(self, action_dim, z_dim, hidden_dim, n_lstm):
        super().__init__()
        
        assert hidden_dim == 128
         
        self.action_dim = action_dim
        self.lstm = nn.LSTM(
            z_dim,
            hidden_dim, n_lstm, batch_first=True, proj_size=action_dim,
        )
        
    def dist(self, batch_z):
        batch_logits, _  = self.lstm(batch_z)     
        return torch_dist.Categorical(logits=batch_logits)


# class LowPolicy(ToDeviceMixin, nn.Module):
#     def __init__(self, action_dim, z_dim, hidden_dim, n_lstm):
#         super().__init__()
        
#         assert hidden_dim == 128
        
#         self.action_dim = action_dim
#         self.vit = ViT(
#             image_size = 84,
#             patch_size = 21,
#             num_classes = 3,
#             dim = 256,
#             depth = 3,
#             heads = 8,
#             mlp_dim = 256,
#             dropout = 0.1,
#             emb_dropout = 0.1,
#             channels = 5,
#         )

#         self.mlp = MLP([256 + z_dim] + [256]*(2-1) + [action_dim], 'relu')
        
#     def dist(self, batch_state, batch_z):
        
# #         input_dim = batch_state.dim()
# #         if input_dim > 4:
            
#         # batch_state = torch.flatten(batch_state, start_dim=0, end_dim=1)
#         # batch_z = torch.flatten(batch_z, start_dim=0, end_dim=1)
#         data_shape = batch_state.shape[:-3]
#         batch_state = batch_state.reshape(-1, *batch_state.shape[-3:])
#         batch_z = batch_z.reshape(-1, *batch_z.shape[-1:])
#         batch_logits = self.vit(batch_state)
#         print(batch_logits.shape, batch_z.shape)
#         batch_logits = torch.cat([batch_logits, batch_z], dim=1)
#         batch_logits = self.mlp(batch_logits)
#         batch_logits = F.log_softmax(batch_logits, dim=-1)
#         batch_logits = batch_logits.reshape(*data_shape, -1)
#         return torch_dist.Categorical(logits=batch_logits)
    
# class GELU(torch.nn.Module):
#     def forward(self, input: torch.Tensor) -> torch.Tensor:
#         return torch.nn.functional.gelu(input)
# torch.nn.modules.activation.GELU = GELU


load = torch.load('../checkpoints/'+ckpt_name, map_location='cpu')
prior_policy = load['prior_policy']
low_policy = load['low_policy']
prior_policy.eval()
low_policy.eval()

LowPolicy(
  (lstm): LSTM(20, 128, proj_size=3, num_layers=2, batch_first=True)
)

In [5]:
sys.path.append('C:\\Users\\Lee Geonju\\Desktop\\drone\\PythonClient')
sys.path.append('C:\\Users\\Lee Geonju\\Desktop\\drone\\PythonClient\\reinforcement_learning')

import time
import torch
import wandb

import multirotor.setup_path
import airsim
import numpy as np
import math
from argparse import ArgumentParser

import wandb
import airsim
import numpy as np
import torch
from airgym.envs.airsim_env import AirSimEnv
from gym import spaces


class AirSimDroneEnv(AirSimEnv):
    def __init__(self, ip_address, step_length, image_shape):
        super().__init__(image_shape)
        self.step_length = step_length
        self.image_shape = image_shape

        self.state = {
            "position": np.zeros(3),
            "collision": False,
            "prev_position": np.zeros(3),
        }

        self.drone = airsim.MultirotorClient(ip=ip_address)
        self.action_space = spaces.Discrete(7)
        self._setup_flight()

        self.image_request = airsim.ImageRequest(
            "0", airsim.ImageType.DepthVis, True, False
            #"0", airsim.ImageType.DepthPlanar, True, False
        )

    def __del__(self):
        self.drone.reset()

    def _setup_flight(self):
        self.drone.reset()
        self.drone.enableApiControl(True)
        self.drone.armDisarm(True)
        
        # Set home position and velocity
        self.drone.takeoffAsync().join()
        #self.drone.moveToPositionAsync(212, -320, -19.0225, 10).join()
        #self.drone.moveByVelocityAsync(1, -0.67, -0.8, 5).join()

    def transform_obs(self, responses):
        img1d = np.array(responses[0].image_data_float, dtype=np.float)
        img1d = 255*img1d.clip(0, 1)
        img2d = np.reshape(img1d, (responses[0].height, responses[0].width))

        from PIL import Image

        image = Image.fromarray(img2d)
        im_final = np.array(image.resize((84, 84)).convert("L"))

        return im_final.reshape([84, 84, 1])

    def _get_obs(self):
        responses = self.drone.simGetImages([self.image_request])
        image = self.transform_obs(responses)
        self.drone_state = self.drone.getMultirotorState()

        self.state["prev_position"] = self.state["position"]
        self.state["position"] = self.drone_state.kinematics_estimated.position
        self.state["velocity"] = self.drone_state.kinematics_estimated.linear_velocity

        collision = self.drone.simGetCollisionInfo().has_collided
        self.state["collision"] = collision
        return image
        

    def _do_action(self, action):
        import time
        start = time.time()
        if action == 0:  # turn left
            self.drone.rotateByYawRateAsync(-8, 1)#.join()
        elif action == 1:  # turn right
            self.drone.rotateByYawRateAsync(8, 1)#.join()
        elif action == 2:  # forward
            future = self.drone.moveByVelocityBodyFrameAsync(5, 0, 0, 100)
#             future = self.drone.moveByVelocityBodyFrameAsync(5, 0, 0, 1).join()
            #for vel in np.linspace(5, 0, 3):
            #    future = self.drone.moveByVelocityBodyFrameAsync(vel, 0, 0, 2)
            #future.join()
        time.sleep(1)

    def _compute_reward(self):
        thresh_dist = 7
        beta = 1

        z = -10
        pts = [
            np.array([-0.55265, -31.9786, -19.0225]),
            np.array([48.59735, -63.3286, -60.07256]),
            np.array([193.5974, -55.0786, -46.32256]),
            np.array([369.2474, 35.32137, -62.5725]),
            np.array([541.3474, 143.6714, -32.07256]),
        ]

        quad_pt = np.array(
            list(
                (
                    self.state["position"].x_val,
                    self.state["position"].y_val,
                    self.state["position"].z_val,
                )
            )
        )

        if self.state["collision"]:
            reward = -100
        else:
            dist = 10000000
            for i in range(0, len(pts) - 1):
                dist = min(
                    dist,
                    np.linalg.norm(np.cross((quad_pt - pts[i]), (quad_pt - pts[i + 1])))
                    / np.linalg.norm(pts[i] - pts[i + 1]),
                )

            if dist > thresh_dist:
                reward = -10
            else:
                reward_dist = math.exp(-beta * dist) - 0.5
                reward_speed = (
                    np.linalg.norm(
                        [
                            self.state["velocity"].x_val,
                            self.state["velocity"].y_val,
                            self.state["velocity"].z_val,
                        ]
                    )
                    - 0.5
                )
                reward = reward_dist + reward_speed

        done = 0
        if reward <= -10:
            done = 1

        return reward, done

    def step(self, action):
        self._do_action(action)
        obs = self._get_obs()
        reward, done = self._compute_reward()

        return obs, reward, done, self.state

    def reset(self):
        self._setup_flight()
        return self._get_obs(), self.state

session_id = wandb.sdk.lib.runid.generate_id()
# env = AirSimDroneEnv('137.68.192.71', 1, (84, 84))
env = AirSimDroneEnv('127.0.0.1', 1, (84, 84))
print('done')

done


In [6]:
from collections import deque

In [7]:
from torchvision.transforms import Compose, RandomRotation, CenterCrop, Resize
from torchvision.transforms.functional import rotate
augment = Compose([
    # RandomRotation(5),
    CenterCrop(75),
    Resize(resize_res)
])

In [8]:
state_queue = deque(maxlen=frame_stack)

In [9]:
def quantize(img, num_level):
    alpha_q = 0
    beta_q = num_level - 1
    alpha = 0
    beta = 255

    s = (beta - alpha) / (beta_q - alpha_q)
    z = round((beta * alpha_q - alpha * beta_q) / (beta - alpha))

    return (img / s).round() + z


def quantize_with_thresholds(img, thresh_list):
    new_img = torch.zeros_like(img)
    n_levels = len(thresh_list) + 1
    for i in range(1, n_levels):
        new_img[img > thresh_list[i-1]] = i
    return new_img
    

def process_state(state):
    state = torch.as_tensor(state).float().view(84, 84)
    state_queue.append(state)
    
    stack = list(state_queue)
    if len(stack) < frame_stack:
        stack = stack + [stack[-1]]*(frame_stack-len(stack))
    stack = torch.stack(stack)
    return augment(stack)

In [10]:
rollouts = []
state = process_state(env.reset()[0])

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  img1d = np.array(responses[0].image_data_float, dtype=np.float)


In [11]:
# torch.save({'rollouts': rollouts}, 'spirl_rollouts.pt')

In [12]:
state_queue = deque(maxlen=frame_stack)
state, info = env.reset()
state = process_state(state)
position = torch.tensor((info['position'].x_val, info['position'].y_val, info['position'].z_val,))

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  img1d = np.array(responses[0].image_data_float, dtype=np.float)


In [None]:
positions = []
collision_count = 0
idx = 0
for i in range(1, 3001):
    if idx % skill_length == 0:
        if pos:
            z = prior_policy.dist(state, position).sample()[None, :].expand(1, skill_length, -1)
        else:
            z = prior_policy.dist(state).sample()[None, None, :].expand(1, skill_length, -1)
#         actions = low_policy.dist(state, z).logits.argmax(-1).squeeze(0)
        actions = low_policy.dist(z).logits.argmax(-1).squeeze(0)
        print(f'step {i}: skill updated', actions)
        # a = low_policy.dist(state, z).logits.argmax()
        # a = low_policy.dist(z).logits.argmax()
    a = actions[idx%skill_length]
    idx = (idx + 1) % skill_length
    
    state, _, _, info = env.step(a.detach().cpu().numpy())
#     plt.imshow(state)
#     plt.show()
    position = torch.tensor((info['position'].x_val, info['position'].y_val, info['position'].z_val,))
    state = process_state(state)
    if is_quantized:
        if num_level != 0:
            if len(thresholds) != 0:
                raise
            state = quantize(state, num_level) / (num_level - 1)
        elif len(thresholds) != 0:
            state = quantize_with_thresholds(state, thresholds) / len(thresholds)
        else:
            raise
    else:
        state = state / 255
    
    if info['collision'] is True:
        collision_count += 1
        print(f'step {i} collision!, {collision_count} time(s) collided')
        state_queue = deque(maxlen=frame_stack)
        state, info = env.reset()
        state = process_state(state)
        position = torch.tensor((info['position'].x_val, info['position'].y_val, info['position'].z_val,))
        if is_quantized:
            state = quantize(state, num_level)
        rollouts.append(positions)
        positions = []
        
##         z = prior_policy.dist(state).sample()[None, None, :].expand(1, skill_length, -1)

#         z = prior_policy.dist(state).sample()[None, :].expand(1, skill_length, -1)
#         actions = low_policy.dist(z).logits.argmax(-1).squeeze(0)

##         actions = low_policy.dist(state, z).logits.argmax(-1).squeeze(0)
        idx = 0
    
    positions.append(info['position'])
print('collided ', collision_count, ' times')
# state, _, _, info = env.step(0)
# state = process_state(state)

step 1: skill updated tensor([2, 2, 2])


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  img1d = np.array(responses[0].image_data_float, dtype=np.float)


step 4: skill updated tensor([2, 2, 1])
step 7: skill updated tensor([0, 0, 2])
step 10: skill updated tensor([0, 0, 0])
step 13: skill updated tensor([0, 0, 0])
step 16: skill updated tensor([0, 0, 0])
step 19: skill updated tensor([0, 0, 0])
step 22: skill updated tensor([0, 0, 0])
step 25: skill updated tensor([0, 0, 0])
step 28: skill updated tensor([0, 0, 0])
step 31: skill updated tensor([2, 2, 2])
step 34: skill updated tensor([2, 2, 2])
step 37: skill updated tensor([2, 0, 0])
step 40: skill updated tensor([2, 1, 1])
step 43: skill updated tensor([1, 1, 1])
step 46: skill updated tensor([1, 2, 1])
step 49: skill updated tensor([2, 2, 0])
step 52: skill updated tensor([2, 0, 0])
step 55: skill updated tensor([0, 0, 0])
step 58: skill updated tensor([0, 0, 0])
step 61: skill updated tensor([1, 1, 1])
step 64: skill updated tensor([1, 2, 1])
step 67: skill updated tensor([2, 1, 1])
step 70: skill updated tensor([0, 2, 0])
step 73: skill updated tensor([0, 0, 0])
step 76: skill upd

In [None]:
prior_policy.conv_net

In [None]:
# rollouts = torch.load('./spirl_rollouts.pt')['rollouts']

In [None]:
print(state.shape, state.min(), state.max())

In [None]:
prior_policy.dist(state) # [10]

In [None]:
low_policy.dist(z)

In [None]:
actions = low_policy.dist(z).logits.argmax(-1).squeeze(0)
actions.shape

In [None]:
# l, r, d, u = 63, 225, -369, -81
l, r, d, u = 20, 225, -369, -81
plt.figure(figsize=((r - l)//32, abs(d - u)//32))
plt.imshow(
    (1-0.3*(1-plt.imread('map.png'))),
    extent=(l, r, u, d)
)
plt.xlim(l, r);plt.ylim(u, d)

for positions in rollouts[:]:
    if len(positions) < 15:
        continue
    plt.plot(*np.array([
        [position.x_val, position.y_val]
        for position in positions
    ]).T, c='royalblue', linewidth=.6, alpha=0.6)
    plt.scatter(positions[0].x_val, positions[0].y_val, marker='^', c='black', s=100)
    plt.scatter(positions[-1].x_val, positions[-1].y_val, marker='x', c='red', s=100, linewidth=1)

In [None]:
max_len = 0
for i, positions in enumerate(rollouts):
    if len(positions) > max_len:
        max_len = len(positions)
        print(f'max_len:{max_len}, idx: {i}')

In [None]:
#########################################################################################################################################