In [2]:
from conllu import parse_incr
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
from collections import deque

In [None]:
class DependencyParsingDataset(Dataset):
    """Transition-based dependency parsing examples read from a CoNLL-U file.

    Each dataset item corresponds to one sentence and is a list of
    ``(features, action)`` pairs, where ``features`` is the dict produced by
    :meth:`extract_features` and ``action`` is one of ``'SHIFT'``,
    ``'LEFT-ARC'`` or ``'RIGHT-ARC'`` as chosen by :meth:`oracle`.

    Token ids follow CoNLL-U: syntactic words carry integer ids, id ``0`` is
    the artificial ROOT, and non-integer (tuple) ids mark multiword-token
    ranges or empty nodes, which are not part of the dependency tree.
    """

    def __init__(self, file_path):
        """Load and pre-compute the transition sequences of every sentence.

        Args:
            file_path: Path to a UTF-8 encoded CoNLL-U file.
        """
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Parse ``file_path`` and return one transition list per sentence."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return [
                self.generate_transitions(sentence)
                for sentence in parse_incr(file)
            ]

    @staticmethod
    def _token_by_id(sentence, token_id):
        """Return the token of ``sentence`` whose ``'id'`` equals ``token_id``.

        The position ``token_id - 1`` is tried first (it is correct whenever
        the sentence holds only integer-id words); otherwise the sentence is
        scanned, so range rows or empty nodes cannot shift the lookup.

        Raises:
            IndexError: If no token carries ``token_id``.
        """
        position = token_id - 1
        if 0 <= position < len(sentence):
            candidate = sentence[position]
            if candidate['id'] == token_id:
                return candidate
        for token in sentence:
            if token['id'] == token_id:
                return token
        raise IndexError(f"No token with id {token_id} in sentence.")

    def oracle(self, stack, buffer, sentence):
        """Choose the gold action for the configuration ``(stack, buffer)``.

        Args:
            stack: Sequence of token ids; ``stack[-1]`` is the top. The
                bottom element is expected to be ROOT (id 0).
            buffer: Sequence of the remaining token ids; ``buffer[0]`` is next.
            sentence: Sequence of tokens providing ``'id'`` and ``'head'``.

        Returns:
            ``'RIGHT-ARC'`` if the buffer front is headed by the stack top,
            ``'LEFT-ARC'`` if the stack top is headed by the buffer front,
            otherwise ``'SHIFT'``.
        """
        # With only ROOT (or nothing) on the stack there is no word to attach
        # yet, and without a buffer front there is nothing to compare against.
        if len(stack) < 2 or not buffer:
            return 'SHIFT'

        top_of_stack = stack[-1]
        first_in_buffer = buffer[0]

        if self._token_by_id(sentence, first_in_buffer)['head'] == top_of_stack:
            return 'RIGHT-ARC'
        # ROOT has no head, so it can never be the dependent of a LEFT-ARC.
        if (top_of_stack != 0
                and self._token_by_id(sentence, top_of_stack)['head'] == first_in_buffer):
            return 'LEFT-ARC'
        return 'SHIFT'

    def generate_transitions(self, sentence):
        """Run the oracle over ``sentence`` and record every decision.

        Args:
            sentence: Sequence of tokens (mappings with ``'id'``, ``'form'``,
                ``'upos'`` and ``'head'``), e.g. a ``conllu`` token list.

        Returns:
            List of ``(features, action)`` tuples, one per parser step, in
            the order the steps were taken.
        """
        # Only integer-id rows are syntactic words. Multiword-token ranges and
        # empty nodes have tuple ids and must not enter the buffer: pushing
        # the first index of a range would enqueue that word twice.
        words = [token for token in sentence if isinstance(token['id'], int)]

        stack = [0]  # ROOT stays at the bottom of the stack throughout
        buffer = deque(token['id'] for token in words)
        transitions = []

        # NOTE(review): RIGHT-ARC drops the dependent instead of pushing it,
        # and the oracle never attaches a word to ROOT; confirm this reduced
        # transition system is what the downstream model expects.
        while buffer:
            action = self.oracle(stack, buffer, words)
            # Features describe the configuration *before* the action applies.
            features = self.extract_features(stack, buffer, words)
            transitions.append((features, action))

            if action == 'SHIFT':
                stack.append(buffer.popleft())
            elif action == 'LEFT-ARC':
                stack.pop()  # stack top is attached to the buffer front
            elif action == 'RIGHT-ARC':
                buffer.popleft()  # buffer front is attached to the stack top

        return transitions

    def extract_features(self, stack, buffer, sentence):
        """Describe the stack top and buffer front of a configuration.

        Args:
            stack: Sequence of token ids; ``stack[-1]`` is the top.
            buffer: Sequence of token ids; ``buffer[0]`` is the front.
            sentence: Sequence of tokens providing ``'id'``, ``'form'`` and
                ``'upos'``.

        Returns:
            Dict with keys ``stack_top_id``, ``stack_top_word``,
            ``stack_top_pos``, ``buffer_first_id``, ``buffer_first_word`` and
            ``buffer_first_pos``. An empty stack/buffer yields id ``0`` and
            ``'NULL'`` strings; ROOT on the stack yields id ``0`` and
            ``'ROOT'`` strings. Word forms are lower-cased.
        """
        features = {
            'stack_top_id': 0, 'buffer_first_id': 0,
            'stack_top_word': 'NULL', 'buffer_first_word': 'NULL',
            'stack_top_pos': 'NULL', 'buffer_first_pos': 'NULL'
        }

        for prefix, token_id in (
            ('stack_top', stack[-1] if stack else None),
            ('buffer_first', buffer[0] if buffer else None),
        ):
            if token_id is None:
                continue  # keep the NULL defaults for an empty container
            features[f'{prefix}_id'] = token_id
            if token_id == 0:
                # ROOT is not a row of the sentence; indexing with 0 - 1
                # would silently return the last word instead.
                features[f'{prefix}_word'] = 'ROOT'
                features[f'{prefix}_pos'] = 'ROOT'
            else:
                token = self._token_by_id(sentence, token_id)
                features[f'{prefix}_word'] = token['form'].lower()
                features[f'{prefix}_pos'] = token['upos']

        return features

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, action)`` list of sentence ``idx``.

        Raises:
            ValueError: If a stored entry is not a 2-element sequence.
        """
        processed_data = []
        for token in self.data[idx]:
            if len(token) != 2:
                raise ValueError(f"Expected each token to be a tuple of 2 elements, got {len(token)} elements.")
            features, action = token
            processed_data.append((features, action))
        return processed_data