In [1]:
import os
import numpy as np
import random as rd
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

!pip install datasets
!pip install conllu

import torch
from functools import partial
from datasets import load_dataset

!pip install evaluate

import matplotlib.pyplot as plt

!pip install "stable-baselines3[extra]>=2.0.0a4"
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env

Collecting datasets
  Downloading datasets-2.17.0-py3-none-any.whl (536 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/536.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━[0m [32m317.4/536.6 kB[0m [31m9.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.6/536.6 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=12.0.0 (from datasets)
  Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.3/38.3 MB[0m [31m34.4 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-

In [2]:
# the function returns whether a tree is projective or not. It is currently
# implemented inefficiently by brute checking every pair of arcs.
def is_projective(tree):
  for i in range(len(tree)):
    if tree[i] == -1:
      continue
    left = min(i, tree[i])
    right = max(i, tree[i])

    for j in range(0, left):
      if tree[j] > left and tree[j] < right:
        return False
    for j in range(left+1, right):
      if tree[j] < left or tree[j] > right:
        return False
    for j in range(right+1, len(tree)):
      if tree[j] > left and tree[j] < right:
        return False

  return True

# the function creates a dictionary of word/index pairs: our embeddings vocabulary
# threshold is the minimum number of appearance for a token to be included in the embedding list
def create_dict(dataset, threshold=3):
  dic = {}  # dictionary of word counts
  for sample in dataset:
    for word in sample['tokens']:
      if word in dic:
        dic[word] += 1
      else:
        dic[word] = 1

  map = {}  # dictionary of word/index pairs. This is our embedding list
  map["<pad>"] = 0
  map["<ROOT>"] = 1
  map["<unk>"] = 2 #used for words that do not appear in our list

  next_indx = 3
  for word in dic.keys():
    if dic[word] >= threshold:
      map[word] = next_indx
      next_indx += 1

  return map

  and should_run_async(code)


In [3]:
train_dataset = load_dataset('universal_dependencies', 'grc_proiel', split="train")
dev_dataset = load_dataset('universal_dependencies', 'grc_proiel', split="validation")
test_dataset = load_dataset('universal_dependencies', 'grc_proiel', split="test")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/3.81M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/296k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/293k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15014 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1019 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1047 [00:00<?, ? examples/s]

In [4]:
# remove the non projective trees in the train dataset
#train_dataset = [sample for sample in train_dataset if is_projective([-1] + [int(head) for head in sample["head"]])]

# create the embedding dictionary
emb_dictionary = create_dict(train_dataset)

In [5]:
def process_sample(sample, get_gold_path = False):

  # put sentence and gold tree in our format
  sentence = ["<ROOT>"] + sample["tokens"]
  gold = [-1] + [int(i) for i in sample["head"]]  #heads in the gold tree are strings, we convert them to int

  # embedding ids of sentence words
  enc_sentence = [emb_dictionary[word] if word in emb_dictionary else emb_dictionary["<unk>"] for word in sentence]

  return enc_sentence, sentence, gold

In [6]:
def prepare_batch(batch_data):
  data = [process_sample(s) for s in batch_data]
  # sentences, paths, moves, trees are parallel arrays, each element refers to a sentence
  enc_sentences = [s[0] for s in data] # input_ids
  sentences = [s[1] for s in data] # sentences
  trees = [s[2] for s in data] # gold_tree
  return enc_sentences, sentences, trees

In [7]:
BATCH_SIZE = 10

bilstm_train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch))
bilstm_dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))
bilstm_test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))

In [8]:
class PrioritizedReplayBuffer():
  def __init__(self, max_size, input_shape, alpha=0.9):
    """
    Initialize the Prioritized Replay Buffer.

    Args:
        max_size (int): The maximum size of the buffer.
        input_shape (tuple): The shape of the inputs.
        alpha (float): Determines how much prioritization is used, with 0 corresponding to no prioritization.
    """
    self.mem_size = max_size
    self.mem_cntr = 0
    self.alpha = alpha  # The exponent alpha determines how much prioritization is used

    # Initialize memory for states, actions, rewards, terminal flags, and priorities
    self.state_memory = np.zeros((self.mem_size, *input_shape), dtype=np.float32)
    self.new_state_memory = np.zeros((self.mem_size, *input_shape), dtype=np.float32)
    self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
    self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
    self.terminal_memory = np.zeros(self.mem_size, dtype=np.uint8)
    self.priority_memory = np.zeros(self.mem_size, dtype=np.float32) + 1e-5  # Initialize with small positive values
    self.max_priority = 1.0  # Initial max priority

  def store_transition(self, state, action, reward, state_, done):
    """
    Store a transition in the buffer.

    Args:
        state: The state of the environment before the action.
        action: The action taken.
        reward: The reward received.
        state_: The state of the environment after the action.
        done: Whether the episode has ended.
    """
    index = self.mem_cntr % self.mem_size  # Circular buffer

    # Store the transition in the respective memory arrays
    self.state_memory[index] = state
    self.new_state_memory[index] = state_
    self.action_memory[index] = action
    self.reward_memory[index] = reward
    self.terminal_memory[index] = done

    # Assign the max priority seen so far to new experiences
    self.priority_memory[index] = self.max_priority

    self.mem_cntr += 1

  def sample_buffer(self, batch_size, beta=0.5):
    """
    Sample a batch of transitions from the buffer.

    Args:
        batch_size (int): The size of the batch to sample.
        beta (float): The exponent for adjusting the importance-sampling weights.

    Returns:
        Tuple containing states, actions, rewards, next states, terminals, indices of the sampled transitions, and the importance-sampling weights.
    """
    # Determine the range of memory to sample from
    num_sampled_elements = min(self.mem_cntr, self.mem_size)
    priorities = self.priority_memory[:num_sampled_elements]

    # Normalize priorities and convert to probabilities
    scaled_priorities = np.power(priorities, self.alpha)
    sample_probs = scaled_priorities / np.sum(scaled_priorities)

    # Sample experiences based on probabilities
    chosen_indices = np.random.choice(num_sampled_elements, batch_size, replace=False, p=sample_probs)

    # Retrieve sampled experiences
    states = self.state_memory[chosen_indices]
    actions = self.action_memory[chosen_indices]
    rewards = self.reward_memory[chosen_indices]
    states_ = self.new_state_memory[chosen_indices]
    terminal = self.terminal_memory[chosen_indices]

    # Compute importance-sampling weights and adjust with beta
    weights = np.power(self.mem_size * sample_probs[chosen_indices], -beta)
    weights /= np.max(weights)  # Normalize for stability
    weights = torch.tensor(weights, dtype=torch.float32).view(-1, 1)  # Convert to tensor and reshape

    return states, actions, rewards, states_, terminal, chosen_indices, weights

  def update_priorities(self, indices, priorities):
    """
    Update the priorities of the sampled transitions in a vectorized manner.

    Args:
        indices (list or numpy.ndarray): Indices of the sampled transitions.
        priorities (list or numpy.ndarray): New priorities for the sampled transitions.
    """
    # Ensure that no priority is set to exactly 0 by using np.maximum, as a priority of 0 would mean a transition is never sampled.
    # Adding a small value (1e-5) ensures all priorities are non-zero and transitions have a chance of being sampled.
    priorities = np.maximum(priorities, 1e-5)

    # Update the priorities in a vectorized manner.
    # This is generally faster and more efficient than a loop, especially for large arrays.
    self.priority_memory[indices] = priorities

    # Update the maximum priority with the largest priority in the new set.
    # This value is used to set the priority for new experiences (ensuring they have a high chance of being sampled initially).
    self.max_priority = max(self.max_priority, np.max(priorities))

In [9]:
class DuelingDeepQNetwork(nn.Module):
  def __init__(self, lr, n_actions, name, input_dims, chkpt_dir):
    """
    Initialize the Dueling Deep Q Network.

    Args:
      lr (float): Learning rate for the optimizer.
      n_actions (int): Number of possible actions.
      name (str): Name of the network, used in saving and loading models.
      input_dims (tuple): Dimensions of the input state.
      chkpt_dir (str): Directory where the checkpoints (model weights) are saved.
    """
    super(DuelingDeepQNetwork, self).__init__()
    self.chkpt_dir = chkpt_dir
    self.checkpoint_file = os.path.join(self.chkpt_dir, name)

    # Define the first fully connected layer
    self.fc1 = nn.Linear(*input_dims, 128)
    # Define the layer for estimating the state-value function V
    self.V = nn.Linear(128, 1)
    # Define the layer for estimating the advantage function A
    self.A = nn.Linear(128, n_actions)

    # Set up the optimizer (Adam) and the loss function (Mean Squared Error)
    self.optimizer = optim.Adam(self.parameters(), lr=lr)
    self.loss = nn.MSELoss()
    # Define the device (use GPU if available)
    self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
    self.to(self.device)

  def forward(self, state):
    """
    Perform a forward pass through the network.

    Args:
      state (torch.Tensor): The input state.

    Returns:
      V (torch.Tensor): The estimated state-value function.
      A (torch.Tensor): The estimated advantage function.
    """
    flat1 = F.relu(self.fc1(state))  # Pass the state through the first fully connected layer
    V = self.V(flat1)  # Compute the state-value function
    A = self.A(flat1)  # Compute the advantage function

    return V, A

  def save_checkpoint(self):
    """
    Save the model's current state.
    """
    print('...saving checkpoint...')
    T.save(self.state_dict(), self.checkpoint_file)

  def load_checkpoint(self):
    """
    Load the model's state from a saved checkpoint.
    """
    print('...loading checkpoint...')
    self.load_state_dict(T.load(self.checkpoint_file))

In [10]:
class Agent():
  def __init__(self, gamma, epsilon, lr, n_actions, input_dims, mem_size, batch_size, eps_min=0.01, eps_dec=5e-7, replace=1000, beta_start=0.5, beta_increment_per_sampling=0.001, beta_max=1.0, chkpt_dir='tmp/dueling_ddqn'):
    """
    Initialize the agent with given hyperparameters and network parameters.

    Args:
      gamma (float): discount factor for future rewards.
      epsilon (float): initial exploration rate for epsilon-greedy action selection.
      lr (float): learning rate for updating the neural network.
      n_actions (int): number of possible actions the agent can take.
      input_dims (tuple): dimensions of the input features.
      mem_size (int): size of the replay memory.
      batch_size (int): number of experiences sampled from memory for each learning step.
      eps_min (float): minimum value for epsilon (exploration rate).
      eps_dec (float): decrement value for epsilon after each episode.
      replace (int): number of steps after which the target network weights are updated.
      beta_start (float): initial value of beta for importance-sampling weights.
      beta_increment_per_sampling (float): increment value for beta after each sampling.
      beta_max (float): maximum value for beta.
      chkpt_dir (str): directory where model checkpoints are saved.
    """
    # Initialize parameters
    self.gamma = gamma
    self.epsilon = epsilon
    self.lr = lr
    self.n_actions = n_actions
    self.input_dims = input_dims
    self.batch_size = batch_size
    self.eps_min = eps_min
    self.eps_dec = eps_dec
    self.replace_target_cnt = replace
    self.beta = beta_start
    self.beta_increment_per_sampling = beta_increment_per_sampling
    self.beta_max = beta_max
    self.chkpt_dir = chkpt_dir
    self.learn_step_counter = 0
    self.action_space = [i for i in range(self.n_actions)]

    # Initialize memory and Dueling DQNs for current and target network
    self.memory = PrioritizedReplayBuffer(mem_size, input_dims)
    self.q_eval = DuelingDeepQNetwork(lr, n_actions, 'q_eval', input_dims, chkpt_dir)
    self.q_next = DuelingDeepQNetwork(lr, n_actions, 'q_next', input_dims, chkpt_dir)

    # Initialize variables for averaging network weights
    self.average_q_eval_state_dict = None  # To store the averaged state dict of the Q_eval network
    self.networks_counter = 0  # To count the number of networks added to the average

  def choose_action(self, observation):
    """
    Choose an action based on the current state and the epsilon-greedy policy.

    Args:
      observation (np.array): the current state observation.

    Returns:
      action (int): the action chosen by the agent.
    """
    if np.random.random() > self.epsilon:
      # Exploitation: choose the best action according to the network's output
      state = T.tensor(np.array(observation), dtype=T.float32).to(self.q_eval.device)
      _, advantage = self.q_eval.forward(state)
      action = T.argmax(advantage).item()
    else:
      # Exploration: choose a random action
      action = np.random.choice(self.action_space)
    return action

  def store_transition(self, state, action, reward, state_, done):
    """
    Store a transition in the replay buffer.

    Args:
      state (np.array): the starting state.
      action (int): the action taken.
      reward (float): the reward received.
      state_ (np.array): the next state after taking the action.
      done (bool): whether the episode is finished.
    """
    self.memory.store_transition(state, action, reward, state_, done)

  def replace_target_network(self):
    """
    Update the target network by copying the weights from the evaluation network.
    This happens every 'replace_target_cnt' learning steps.
    """
    if self.learn_step_counter % self.replace_target_cnt == 0:
      self.q_next.load_state_dict(self.q_eval.state_dict())

  def update_average_network(self, current_state_dict):
    """
    Update the running average of the Q_eval network weights.
    This is intended to stabilize the training by smoothing out the variations in the network weights over training steps.

    Args:
      current_state_dict (dict): state_dict of the current Q_eval network.
    """
    self.networks_counter += 1
    if self.average_q_eval_state_dict is None:
      self.average_q_eval_state_dict = {k: v.clone().detach() for k, v in current_state_dict.items()}
    else:
      new_average_q_eval_state_dict = {}
      for key in self.average_q_eval_state_dict.keys():
        new_average_q_eval_state_dict[key] = (
            self.average_q_eval_state_dict[key] * (self.networks_counter - 1)
            + current_state_dict[key]
        ) / self.networks_counter
      self.average_q_eval_state_dict = new_average_q_eval_state_dict

  def decrement_epsilon(self):
    """
    Decrement the epsilon value to reduce exploration over time.
    """
    self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

  def save_models(self):
    """
    Save the current and target network models.
    """
    self.q_eval.save_checkpoint()
    self.q_next.save_checkpoint()

  def load_models(self):
    """
    Load the saved models for the current and target networks.
    """
    self.q_eval.load_checkpoint()
    self.q_next.load_checkpoint()

  def learn(self):
    """
    The learning process for the agent. Samples a batch of experiences and updates the network.
    """
    if self.memory.mem_cntr < self.batch_size:
      return  # Do not learn until enough samples are available

    self.q_eval.optimizer.zero_grad()

    # Update the target network and the average network at the specified intervals
    if self.learn_step_counter % self.replace_target_cnt == 0:
      self.replace_target_network()
      self.update_average_network(self.q_eval.state_dict())

    # Sample a batch from the replay buffer
    states, actions, rewards, states_, dones, indices, weights = self.memory.sample_buffer(self.batch_size, self.beta)

    states = T.tensor(states).to(self.q_eval.device)
    actions = T.tensor(actions).to(self.q_eval.device)
    dones = T.tensor(dones).to(self.q_eval.device)
    rewards = T.tensor(rewards).to(self.q_eval.device)
    states_ = T.tensor(states_).to(self.q_eval.device)
    weights = weights.clone().detach().requires_grad_(True).to(self.q_eval.device)

    batch_indices = np.arange(self.batch_size)

    # Load the averaged network weights for predicting the next Q-values
    self.q_eval.load_state_dict(self.average_q_eval_state_dict)

    V_s, A_s = self.q_eval.forward(states)
    V_s_avg, A_s_avg = self.q_eval.forward(states_)

    q_pred = T.add(V_s, (A_s - A_s.mean(dim=1, keepdim=True)))[batch_indices, actions]
    q_next = T.add(V_s_avg, (A_s_avg - A_s_avg.mean(dim=1, keepdim=True)))
    q_next[dones.bool()] = 0.0  # Set Q value of next state to 0 if the episode ended
    q_target = rewards + self.gamma * q_next[batch_indices, T.argmax(A_s_avg, dim=1)]

    # Compute loss, perform backpropagation, and update network weights
    loss = self.q_eval.loss(q_target, q_pred) * weights  # Apply importance-sampling weights
    loss = loss.mean()  # Average the loss over the batch
    loss.backward()
    self.q_eval.optimizer.step()

    # Update learning step counter and epsilon
    self.learn_step_counter += 1
    self.decrement_epsilon()

    # Increment beta, ensuring it doesn't exceed beta_max
    self.beta = min(self.beta + self.beta_increment_per_sampling, self.beta_max)

    # Update the priorities in the replay buffer based on TD error
    td_errors = (q_target - q_pred).detach().cpu().numpy()
    new_priorities = np.abs(td_errors) + 1e-5  # Ensure priorities are non-zero
    self.memory.update_priorities(indices, new_priorities)


In [11]:
class ArcStandard:
  def __init__(self, sentence, tree):
    self.gold_tree = tree
    self.sentence = sentence
    self.buffer = [i for i in range(len(self.sentence))]
    self.stack = []
    self.arcs = [-1 for _ in range(len(self.sentence))]
    self.prev_actions = [None, None, None, None, None]

    # three shift moves to initialize the stack
    self.shift()
    self.shift()
    if len(self.sentence) > 2:
      self.shift()

    self.loss = [0 for i in range(len(self.stack))]

  def shift(self):
    b1 = self.buffer[0]
    self.buffer = self.buffer[1:]
    self.stack.append(b1)
    if len(self.prev_actions) == 5:
      self.prev_actions.pop(0)
    self.prev_actions.append('shift')

  def left_arc(self):
    o1 = self.stack.pop()
    o2 = self.stack.pop()
    self.arcs[o2] = o1
    self.stack.append(o1)
    if len(self.prev_actions) == 5:
      self.prev_actions.pop(0)
    self.prev_actions.append('left_arc')
    if len(self.stack) < 2 and len(self.buffer) > 0:
      self.shift()
      if len(self.prev_actions) == 5:
        self.prev_actions.pop(0)
      self.prev_actions.append('shift')

  def right_arc(self):
    o1 = self.stack.pop()
    o2 = self.stack.pop()
    self.arcs[o1] = o2
    self.stack.append(o2)
    if len(self.prev_actions) == 5:
      self.prev_actions.pop(0)
    self.prev_actions.append('right_arc')
    if len(self.stack) < 2 and len(self.buffer) > 0:
      self.shift()
      if len(self.prev_actions) == 5:
        self.prev_actions.pop(0)
      self.prev_actions.append('shift')

  def is_tree_final(self):
    return len(self.stack) == 1 and len(self.buffer) == 0

  def print_configuration(self):
    s = [self.sentence[i] for i in self.stack]
    b = [self.sentence[i] for i in self.buffer]
    print(s, b)
    print(self.arcs)

  def get_valid_actions(self):
    """
    Determine the valid actions that can be taken from the current state of the parser.

    Returns:
      list: A list of valid actions.
    """
    valid_actions = ['shift', 'left_arc', 'right_arc']

    # 'shift' is not valid if the buffer is empty
    if len(self.buffer) == 0:
      valid_actions.remove('shift')

    # 'left_arc' is not valid if:
    # 1. The stack has less than 2 elements
    # 2. The stack has exactly 2 elements but the buffer is not empty
    # 3. The second-to-last element on the stack is the root (0)
    if len(self.stack) < 2 or (len(self.stack) == 2 and len(self.buffer) != 0) or self.stack[-2] == 0:
      valid_actions.remove('left_arc')

    # 'right_arc' is not valid if:
    # 1. The stack has less than 2 elements
    # 2. The second-to-last element on the stack is the root (0) and the buffer is not empty
    if len(self.stack) < 2 or (self.stack[-2] == 0 and len(self.buffer) != 0):
      valid_actions.remove('right_arc')

    return valid_actions

  def get_binary_features(self, N=10):
    """
    Construct the binary feature vector for the current state.
    Each of the top 10 tokens from the stack and the first 10 tokens from the buffer
    will have their gold head position binary encoded using 5 bits, and additional bits
    indicating if the gold head is lost and if all dependents are already collected.

    Args:
      N (int): The number of tokens from the stack and buffer to consider.

    Returns:
      np.array: The binary feature vector representing the current state.
    """
    # Initialize the binary feature vector
    binary_features = []

    # Get the top N tokens from the stack and the first N tokens from the buffer
    stack_elements = self.stack[-N:] if len(self.stack) >= N else self.stack + [-1] * (N - len(self.stack))
    buffer_elements = self.buffer[:N] if len(self.buffer) >= N else self.buffer + [-1] * (N - len(self.buffer))

    # Combine stack and buffer elements for easier indexing
    combined_elements = stack_elements + buffer_elements

    # Get the complete stack and buffer for checking if the gold head is lost
    complete_elements = self.stack + self.buffer

    # Encode the position of the gold head for each element in stack_elements and buffer_elements
    for token_index in combined_elements:
      if token_index == -1:
        binary_features.extend([-1, -1, -1, -1, -1, -1, -1])  # Padding representation with 7 bits
      else:
        if token_index == 0:
          binary_features.extend([1, 1, 1, 1, 1, 0] + [self.has_collected_all_dependents(token_index)])
        else:
          gold_head = self.gold_tree[token_index]
          gold_head_pos = combined_elements.index(gold_head) if gold_head in combined_elements else -1
          gold_head_lost = 1 if (self.gold_tree[token_index] not in complete_elements and token_index != 0) else 0
          all_dependents_collected = self.has_collected_all_dependents(token_index)
          if gold_head_pos == -1:
            binary_features.extend([-1, -1, -1, -1, -1, gold_head_lost, all_dependents_collected])  # Gold head is lost or not in the 20 elements
          else:
            binary_features.extend([int(bit) for bit in self.position_to_binary(gold_head_pos)] + [gold_head_lost, all_dependents_collected])
    # Encode the last 5 (or fewer, with padding) actions leading to this state
    binary_features.extend(self.get_padded_prev_actions(self.prev_actions))

    # Encode all valid actions in this state
    # Assuming 'get_valid_actions' returns a list of valid actions in the current state
    valid_actions = self.get_valid_actions()
    binary_features.extend([1 if action in valid_actions else 0 for action in ['shift', 'left_arc', 'right_arc']])

    return np.array(binary_features)

  def position_to_binary(self, pos, max_pos=20):
    """
    Convert a position to a 5-bit binary representation.
    If the position is out of range (lost or not among the 20 elements), return '00000'.

    Args:
        pos (int): The position to be converted.
        max_pos (int): The maximum position value (20 for top 10 in stack and first 10 in buffer).

    Returns:
        str: A 5-bit binary string representing the position.
    """
    if pos < 0 or pos >= max_pos:
        return '00000'
    return format(pos, '05b')

  def has_collected_all_dependents(self, first_common_parent):
    for token in self.stack:
      if self.gold_tree[token] == first_common_parent:
        return 0

    for token in self.buffer:
      if self.gold_tree[token] == first_common_parent:
        return 0

    return 1

  def action_to_binary(self, action):
    """
    Convert an action to its binary (one-hot encoded) representation.

    Args:
      action (str): The action to be converted.

    Returns:
      list: The binary representation of the action.
    """
    if action == 'left_arc':
      return [1, 0]
    elif action == 'right_arc':
      return [0, 1]
    elif action == 'shift':
      return [1, 1]
    else:  # For padding or unknown actions
      return [0, 0]

  def get_padded_prev_actions(self, prev_actions, max_prev_actions=5):
    """
    Get the binary representations of previous actions, padded with zeros if there are fewer than 'max_prev_actions'.

    Args:
      prev_actions (list): The list of the last few actions taken.
      max_prev_actions (int): The maximum number of previous actions to consider.

    Returns:
      list: A flattened list containing the binary representations of previous actions, padded with zeros.
    """
    # Convert each previous action to its binary representation
    binary_prev_actions = [self.action_to_binary(action) for action in prev_actions]

    # Calculate the number of actions to pad
    num_padding = max_prev_actions - len(binary_prev_actions)

    # Pad with vectors representing 'no action'
    binary_prev_actions.extend([self.action_to_binary(None)] * num_padding)

    # Flatten the list of binary vectors into a single list
    return [bit for action_bits in binary_prev_actions for bit in action_bits]


In [27]:
class DependencyParsingEnv(gym.Env):
  metadata = {'render.modes': ['human']}

  def __init__(self, sentence, tree, max_steps_per_episode=5):
    super(DependencyParsingEnv, self).__init__()
    self.sentence = sentence
    self.tree = tree
    self.parser = ArcStandard(sentence, tree)
    self.previous_action = [-1, 0]
    self.positive_reward = 100
    self.current_step = 0
    self.max_steps_per_episode = max_steps_per_episode

    # Define action and observation space
    self.action_space = spaces.Discrete(3)
    self.observation_space = spaces.Box(low=-1, high=1, shape=(self.parser.get_binary_features().shape[0],), dtype=np.float32)

  def get_valid_actions(self):
    valid_actions = self.parser.get_valid_actions()
    valid_actions_indexes = []
    if 'left_arc' in valid_actions:
      valid_actions_indexes.append(0)
    if 'right_arc' in valid_actions:
      valid_actions_indexes.append(1)
    if 'shift' in valid_actions:
      valid_actions_indexes.append(2)

    return valid_actions_indexes

  def step(self, action):
    self.current_step += 1
    valid_actions = self.get_valid_actions()
    # Map the action to the parser's functions
    if action == 0 and action in valid_actions:  # left_arc
      self.parser.left_arc()
    elif action == 1 and action in valid_actions:  # right_arc
      self.parser.right_arc()
    elif action == 2 and action in valid_actions:  # shift
      self.parser.shift()

    # Compute the reward for the current action
    reward, _ = self.computeReward(self.parser.stack, self.parser.buffer, self.parser.gold_tree, action, self.previous_action)

    # Update the previous action
    self.previous_action = [action, reward]

    # Check if the episode (parsing of one sentence) is done
    done = self.parser.is_tree_final()

    # Check if max steps per episode is reached
    truncated = False
    if self.current_step >= self.max_steps_per_episode:
      done = True
      truncated = True

    # Get the next state representation
    state = self.parser.get_binary_features().astype(np.float32)

    # Additional info can be added if necessary
    info = {}

    return state, reward, done, truncated, info

  def reset(self, seed=None, options=None):
    # Reset the state of the environment to an initial state
    self.parser = ArcStandard(self.sentence, self.tree)
    self.previous_action = [-1, 0]
    self.current_step = 0
    observation = self.parser.get_binary_features().astype(np.float32)
    info = {}  # Optional: can contain additional information
    return observation, info

  def render(self, mode='human', close=False):
    # Render the environment to the screen
    self.parser.print_configuration()

  def computeSimpleReward(self, stack, buffer, gold_tree, action, previous_action):
    # LEFT_ARC
    if action == 0:
      if len(stack) < 2 or (len(stack) == 2 and len(buffer) != 0) or stack[-2] == 0:
        return -100, False
      reward = 0
      s1 = stack[-1]
      s2 = stack[-2]

      if gold_tree[s2] == s1:
        reward += 2

      for i in stack:
        if gold_tree[i] == s2 or gold_tree[s2] == i:
          reward -= 1

      for i in buffer:
        if gold_tree[i] == s2 or gold_tree[s2] == i:
          reward -= 1

      return reward, False
    # RIGHT_ARC
    elif action == 1:
      if len(stack) < 2 or (stack[-2] == 0 and len(buffer) > 0):
        return -100, False
      reward = 0

      s1 = stack[-1]
      s2 = stack[-2]

      if gold_tree[s1] == s2:
        reward += 2

      for i in stack:
        if gold_tree[i] == s1 or gold_tree[s1] == i:
          reward -= 1

      for i in buffer:
        if gold_tree[i] == s1 or gold_tree[s1] == i:
          reward -= 1

      return reward, False
    # SHIFT
    elif action == 2:
      if len(buffer) == 0:
        return -100, False

      return 0, False

  def computeReward(self, stack, buffer, gold_tree, action, previous_action):
    # LEFT_ARC
    if action == 0:
      if len(stack) < 2 or (len(stack) == 2 and len(buffer) != 0) or stack[-2] == 0:
        return -100, False
      reward = 0
      s1 = stack[-1]
      s2 = stack[-2]

      if gold_tree[s2] == s1:
        reward += 2

      for i in stack:
        if gold_tree[i] == s2 or gold_tree[s2] == i:
          reward -= 1

      for i in buffer:
        if gold_tree[i] == s2 or gold_tree[s2] == i:
          reward -= 1

      if previous_action[0] == 2:
        reward -= previous_action[1]

      if reward == 1:
        reward = self.positive_reward

      return reward, False
    # RIGHT_ARC
    elif action == 1:
      if len(stack) < 2 or (stack[-2] == 0 and len(buffer) > 0):
        return -100, False
      reward = 0

      s1 = stack[-1]
      s2 = stack[-2]

      if gold_tree[s1] == s2:
        reward += 2

      for i in stack:
        if gold_tree[i] == s1 or gold_tree[s1] == i:
          reward -= 1

      for i in buffer:
        if gold_tree[i] == s1 or gold_tree[s1] == i:
          reward -= 1

      if previous_action[0] == 2:
        reward -= previous_action[1]

      if reward == 1:
        reward = self.positive_reward

      return reward, False
    # SHIFT
    elif action == 2:
      if len(buffer) == 0:
        return -100, False

      reward = 0
      s1 = stack[-1]

      for i in buffer:
        if gold_tree[i] == s1:
          return self.positive_reward, False # a right child allows a costless shift

      # s1 is a right child without right children
      if gold_tree[s1] < s1:
        b1 = buffer[0]
        sacrifice = 0
        # search for a lost father so that we can create an arc between s1 and the orphan node
        orphan = False
        father = gold_tree[b1]
        #print(b1, " ", father)
        while not orphan and father != 0:
          #print("QUA")
          flag = father in stack
          #print("flag: ", flag)
          if (father not in buffer and not flag):
            orphan = True
          if flag:
            #print("father in stack", father, )
            return -1, False
          #print("orphan: ", orphan)
          father = gold_tree[father]

        if orphan:
          #print("orphan")
          return 0, False

        for i in stack:
          if gold_tree[i] == b1 or gold_tree[b1] == i:
            sacrifice -= 1
        for i in buffer:
          if gold_tree[i] == b1 or gold_tree[b1] == i:
            sacrifice -= 1

        for i in stack:
          if gold_tree[i] == s1 or gold_tree[s1] == i:
            reward -= 1

        if reward == 0:
          return self.positive_reward, False

        return max(reward, sacrifice), False

      # s1 is a left child with no right children
      for i in stack:
        if gold_tree[i] == s1:
          reward -= 1

      if reward == 0:
        reward = self.positive_reward
      return reward, False

In [13]:
sentence = ['<ROOT>', 'Hello', 'World', '!']
gold_tree = [-1, 0, 1, 1]

env = DependencyParsingEnv(sentence, gold_tree)
# If the environment don't follow the interface, an error will be thrown
check_env(env, warn=True)

In [14]:
def evaluate(gold, preds):
  total = 0
  correct = 0

  for g, p in zip(gold, preds):
    for i in range(1,len(g)):
      total += 1
      if g[i] == p[i]:
        correct += 1

  return correct/total

In [15]:
EMBEDDING_SIZE = 200
LSTM_SIZE = 200
LSTM_LAYERS = 2
MLP_SIZE = 200
DROPOUT = 0.2
EPOCHS = 15
LR = 0.001   # learning rate
PROBABILITY_THRESHOLD = 0.1

In [16]:
class BilstmParser(nn.Module):

  def __init__(self, device):
    super(BilstmParser, self).__init__()
    self.device = device
    self.embeddings = nn.Embedding(len(emb_dictionary), EMBEDDING_SIZE, padding_idx=emb_dictionary["<pad>"])

    # initialize bi-LSTM
    self.lstm = nn.LSTM(EMBEDDING_SIZE, LSTM_SIZE, num_layers = LSTM_LAYERS, bidirectional=True, dropout=DROPOUT)

    # initialize feedforward
    self.w1 = torch.nn.Linear(8*LSTM_SIZE, MLP_SIZE, bias=True)
    self.activation = torch.nn.Tanh()
    self.w2 = torch.nn.Linear(MLP_SIZE, 3, bias=True)
    self.softmax = torch.nn.Softmax(dim=-1)

    self.dropout = torch.nn.Dropout(DROPOUT)

    #self.x = []
    self.h = torch.zeros(1,1,1)

  def forward(self, x, paths, flag_enc, flag_feat):
    if flag_enc:
      # get the embeddings
      x = [self.dropout(self.embeddings(torch.tensor(i).to(self.device))) for i in x]
      #if flag_feat:
      # run the bi-lstm
      self.h = self.lstm_pass(x) # size(longest_sentence, batch_size, features)

    # for each parser configuration that we need to score we arrange from the
    # output of the bi-lstm the correct input for the feedforward
    mlp_input = self.get_mlp_input(paths, self.h)

    # run the feedforward and get the scores for each possible action
    out = self.mlp(mlp_input)

    return out

  def lstm_pass(self, x):
    x = torch.nn.utils.rnn.pack_sequence(x, enforce_sorted=False)
    h, (h_0, c_0) = self.lstm(x)
    h, h_sizes = torch.nn.utils.rnn.pad_packed_sequence(h) # size h: (length_sentences, batch, output_hidden_units)
    return h

  def get_mlp_input(self, configurations, h):
    mlp_input = []
    zero_tensor = torch.zeros(2*LSTM_SIZE, requires_grad=False).to(self.device)
    for i in range(len(configurations)): # for every sentence in the batch
      mlp_input.append(torch.cat([zero_tensor if configurations[i][0] == -1 else h[configurations[i][0]][i], zero_tensor if configurations[i][1] == -1 else h[configurations[i][1]][i], zero_tensor if configurations[i][2]==-1 else h[configurations[i][2]][i], zero_tensor if configurations[i][3] == -1 else h[configurations[i][3]][i]]))
    mlp_input = torch.stack(mlp_input).to(self.device)
    return mlp_input

  def mlp(self, x):
    return self.softmax(self.w2(self.dropout(self.activation(self.w1(self.dropout(x))))))

  # we use this function at inference time. We run the parser and at each step
  # we pick as next move the one with the highest score assigned by the model
  def infere(self, x):

    parsers = [ArcStandard(i) for i in x]

    x = [self.embeddings(torch.tensor(i).to(self.device)) for i in x]

    h = self.lstm_pass(x)

    while not self.parsed_all(parsers):
      # get the current configuration and score next moves
      configurations = self.get_configurations(parsers)
      mlp_input = self.get_mlp_input(configurations, h)
      mlp_out = self.mlp(mlp_input)
      # take the next parsing step
      self.parse_step(parsers, mlp_out)

    # return the predicted dependency tree
    return [parser.arcs for parser in parsers]

  def get_configurations(self, parsers):
    configurations = []

    for parser in parsers:
      if parser.is_tree_final():
        conf = [-1, -1, -1, -1]
      else:
        if len(parser.stack) == 0:
          conf = [-1, -1, -1]
        elif len(parser.stack) == 1:
          conf = [-1, -1, parser.stack[-1]]
        elif len(parser.stack) == 2:
          conf = [-1, parser.stack[-2], parser.stack[-1]]
        else:
          conf = [parser.stack[-3], parser.stack[-2], parser.stack[-1]]
        if len(parser.buffer) == 0:
          conf.append(-1)
        else:
          conf.append(parser.buffer[0])
      configurations.append(conf)

    return configurations

  def parsed_all(self, parsers):
    for parser in parsers:
      if not parser.is_tree_final():
        return False
    return True

  # in this function we select and perform the next move according to the scores obtained.
  def parse_step(self, parsers, moves):
    moves_argm = moves.argmax(-1)
    for i in range(len(parsers)):
      if parsers[i].is_tree_final():
        continue
      else:
        if moves_argm[i] == 0:
          if parsers[i].stack[-2] != 0:
            parsers[i].left_arc()
          else:
            if len(parsers[i].buffer) > 0:
              parsers[i].shift()
            else:
              parsers[i].right_arc()
        elif moves_argm[i] == 1:
          if parsers[i].stack[-2] == 0 and len(parsers[i].buffer)>0:
            parsers[i].shift()
          else:
            parsers[i].right_arc()
        elif moves_argm[i] == 2:
          if len(parsers[i].buffer) > 0:
            parsers[i].shift()
          else:
            if moves[i][0] > moves[i][1]:
              if parsers[i].stack[-2] != 0:
                parsers[i].left_arc()
              else:
                parsers[i].right_arc()
            else:
              parsers[i].right_arc()

In [17]:
def find_min_indices(nums):
  min_value = min(nums)
  min_indices = [i for i, num in enumerate(nums) if num == min_value]
  return min_indices

def execute(parsers, actions, oracles, costs):
  for parser, action, oracle, cost in zip(parsers, actions, oracles, costs):
    if parser.is_tree_final():
      continue
    else:
      if action == 0:
        parser.left_arc()
        oracle.previous_action = [0, cost[0]]
      elif action == 1:
        parser.right_arc()
        oracle.previous_action = [1, cost[1]]
      elif action == 2:
        parser.shift()
        oracle.previous_action = [2, cost[2]]

def choose_next_amb(iteration, transition, min_cost):
  if transition in min_cost:
    return transition
  else:
    return min_cost[rd.randint(0, len(min_cost) - 1)]

def choose_next_exp(iteration, transition, min_cost):
  if iteration >= 1 and rd.random() > PROBABILITY_THRESHOLD:
    return transition
  else:
    return choose_next_amb(iteration, transition, min_cost)

def parsed_all(parsers):
  for parser in parsers:
    if not parser.is_tree_final():
      return False
  return True

In [18]:
def train(model, dataloader, criterion, optimizer, epoch, device):
  model.train()
  total_loss = 0
  count = 0
  error_count = 0

  # For each batch
  for batch in dataloader:
    # Extract sentence enconding, sentence itself and gold tree for each sentence in the batch
    enc_sentences, sentences, trees = batch
    # Reset the gradient for the current batch
    optimizer.zero_grad()
    # Containers to store transitions scores and respective gold labels
    global_transitions_scores = []
    global_gold_transitions = []

    # Flag to tell the model whether to save the encodings and the features tensor h for a new batch or not
    flag_enc = True
    flag_feat = True
    # Initialize a parser and an oracle for each sentence in the batch
    parsers = [ArcStandard(s) for s in sentences]

    # While each sentence hasn't been fully parsed
    while not parsed_all(parsers):

      # Save configuration: later we'll need the sequence of configurations in order to associate each one to the correct transition
      configurations = []
      for parser in parsers:
        if parser.is_tree_final():
          configurations.append([-1, -1, -1, -1])
        else:
          if len(parser.stack) == 0:
            configurations.append([-1, -1, -1])
          elif len(parser.stack) == 1:
            configurations.append([-1, -1, parser.stack[-1]])
          elif len(parser.stack) == 2:
            configurations.append([-1, parser.stack[-2], parser.stack[-1]])
          else:
            configurations.append([parser.stack[-3], parser.stack[-2], parser.stack[-1]])
          if len(parser.buffer) == 0:
            configurations[-1].append(-1)
          else:
            configurations[-1].append(parser.buffer[0])

      # The model produce the scores for each transition given the current configuration
      transitions_scores_tensor = model(enc_sentences, configurations, flag_enc, flag_feat)
      transitions_scores = transitions_scores_tensor.cpu().detach().numpy()
      flag_enc = False
      flag_feat = False

      # Cost of each transition for each current configuration
      costs = [oracle.provideTransitionCosts() for oracle in oracles]

      # Legal transitions for each current configuration
      legal_moves = [[index for index, value in enumerate(cost) if value != float('inf')] for cost in costs]

      # Legal transition with higher score according to the model for each current configuration
      predicted_transition = [moves[np.argmax([scores[i] for i in moves])] if not parser.is_tree_final() else -1 for scores, moves, parser in zip(transitions_scores, legal_moves, parsers)]

      # Collect the set of transitions with minimum cost for each current configuration
      min_cost_transitions = [find_min_indices(cost) for cost in costs]

      # Collect the best scoring transition among the ones with minimum cost for each current configuration
      best_min_cost_transitions = [
        max(
        (score, i) for i, score in enumerate(scoring_quadruplet) if i in min_cost_transition
        )[1]
        for scoring_quadruplet, min_cost_transition in zip(transitions_scores, min_cost_transitions)
      ]

      # Select a transition: the one predicted by the model or a randomly chosen one from the set of minimum cost transitions
      actual_transitions = [choose_next_exp(epoch, predicted_transition[i], min_cost_transitions[i]) for i in range(len(predicted_transition))]

      # Check if the predicted transition is among the ones with minimum cost: if not we need to update the model
      for i, parser in enumerate(parsers):
        if not parser.is_tree_final():# and predicted_transition[i] not in min_cost_transitions[i]:
          global_transitions_scores.append(transitions_scores_tensor[i])
          global_gold_transitions.append(best_min_cost_transitions[i])

      # Perform the decided transition
      execute(parsers, actual_transitions, oracles, costs)
      #executeExatDynamicOracle(actual_transitions, parsers, trees)

    total_loss += loss.item()
    count +=1

  return total_loss/count

In [19]:
def evaluateSingleTree(gold, preds):
  total = 0
  correct = 0

  for i in range(1,len(gold)):
    total += 1
    if gold[i] == preds[i]:
      correct += 1

  return correct/total

In [20]:
!pip install tensorboardX --no-cache-dir
!pip install tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [None]:
max_forked_episodes = 5
max_episode_length = 100
max_epochs= 15
#torch.autograd.set_detect_anomaly(True)

agent = Agent(gamma=0.9, epsilon=1.0, lr=5e-4, n_actions=3, input_dims=[153], mem_size=50000, batch_size=1000, eps_min=0.01, eps_dec=1e-7, replace=100)

for epoch in range(max_epochs):
  counter = 0
  tot_reward = 0
  tot_epsilon = 0
  for batch_data in bilstm_train_dataloader:
    enc_sentences, sentences, trees = batch_data
    for enc_sentence, sentence, tree in zip(enc_sentences, sentences, trees):
      for _ in range(max_forked_episodes):
        env = DependencyParsingEnv(sentence, tree, max_steps_per_episode=max_episode_length)
        state = env.reset()[0]
        for _ in range(max_episode_length):
          action = agent.choose_action(state)
          valid_actions = env.get_valid_actions()
          if np.random.rand() < 0.05 and len(valid_actions) != 0:  # Forking probability
            action = np.random.choice(valid_actions)
          next_state, reward, done, truncated, _ = env.step(action)
          agent.store_transition(state, action, reward, next_state, done)
          #print("Epsilon ", agent.epsilon)
          #print("Reward ", reward)
          counter += 1
          tot_epsilon += agent.epsilon
          tot_reward += reward
          agent.learn()
          state = next_state
          if done or truncated:
            break
        if done:
          break  # No more forking if the true end of the sentence is reached
  writer.add_scalar('epsilon', tot_epsilon / counter, epoch)
  print(tot_epsilon / counter)
  writer.add_scalar('reward', tot_reward / counter, epoch)
  print(tot_reward / counter)

  count = 0
  tot_loss = 0
  for batch_data in bilstm_train_dataloader:
    enc_sentences, sentences, trees = batch_data
    for enc_sentence, sentence, tree in zip(enc_sentences, sentences, trees):
      env = DependencyParsingEnv(sentence, tree, max_steps_per_episode=500)
      state = env.reset()[0]
      while not env.parser.is_tree_final():
        action = agent.choose_action(state)
        next_state, reward, done, truncated, _ = env.step(action)
        state = next_state
        if done or truncated or reward == -100:
          #print("Gold Tree: ", tree)
          #print("Parsed Tree: ", env.parser.arcs, "\n")
          count += 1
          tot_loss = evaluateSingleTree(tree, env.parser.arcs)
          break

  # Print epoch summary
  print(f'Epoch: {epoch}, UAS: {tot_loss/count}')


0.9746670018135641
-40.241363040319904
Epoch: 0, UAS: 0.0


In [25]:
!tensorboard --logdir runs # runs is the name of the folder that has summaries saved

2024-02-15 18:35:31.478663: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-15 18:35:31.478732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-15 18:35:31.480146: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.2 at http://localhost:6006/ (Press CTRL+C to quit)
