In [1]:
!pip install -U conllu datasets



In [2]:
import torch
from torch.utils.data import DataLoader
import transformers
from transformers import BertTokenizer, BertModel
import numpy as np
import pandas as pd
from datasets import load_dataset

In [3]:
# Load the "en_ewt" and "de_gsd" configurations of the universal_dependencies dataset.
english_data = load_dataset("universal_dependencies", "en_ewt")
german_data = load_dataset("universal_dependencies", "de_gsd")

# Show the splits, columns and row counts of both treebanks.
print(english_data)
print(german_data)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 12543
    })
    validation: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 2002
    })
    test: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 2077
    })
})
DatasetDict({
    train: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 13814
    })
    validation: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],
        num_rows: 799
    })
    test: Dataset({
        features: ['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'],


In [42]:
class ArcStandardParser:
    """Minimal unlabeled arc-standard transition parser.

    The configuration is a stack, a buffer and the list of arcs built so far.
    Every stack/buffer item is an ``(index, token)`` tuple: index 0 is the
    artificial ROOT and real tokens are numbered from 1. Arcs are stored as
    ``(head_index, dependent_index)`` pairs in creation order.
    """

    ROOT_INDEX = 0

    def __init__(self):
        self.stack = []
        self.buffer = []
        self.dependencies = []

    def initialize(self, sentence):
        """Reset the configuration for ``sentence`` (an iterable of tokens)."""
        self.stack = [(self.ROOT_INDEX, "ROOT")]  # Root represented as index 0
        self.buffer = [(position, token) for position, token in enumerate(sentence, start=1)]
        self.dependencies = []

    def shift(self):
        """Move the first buffer item onto the stack; no-op on an empty buffer."""
        if self.buffer:
            self.stack.append(self.buffer.pop(0))

    def left_arc(self):
        """Make the second-from-top stack item a dependent of the top item.

        The dependent is removed from the stack. No-op when fewer than two
        items are stacked, or when the would-be dependent is ROOT: ROOT must
        never receive a head, and popping it would corrupt the configuration.
        """
        if len(self.stack) >= 2 and self.stack[-2][0] != self.ROOT_INDEX:
            dependent = self.stack.pop(-2)
            head = self.stack[-1]
            self.dependencies.append((head[0], dependent[0]))

    def right_arc(self):
        """Make the top stack item a dependent of the item below it.

        The dependent is removed from the stack. No-op when fewer than two
        items are stacked.
        """
        if len(self.stack) >= 2:
            dependent = self.stack.pop(-1)
            head = self.stack[-1]
            self.dependencies.append((head[0], dependent[0]))

    def parse_step(self, action):
        """Apply one transition named ``"SHIFT"``, ``"LEFT-ARC"`` or ``"RIGHT-ARC"``.

        Any other action name is ignored.
        """
        if action == "SHIFT":
            self.shift()
        elif action == "LEFT-ARC":
            self.left_arc()
        elif action == "RIGHT-ARC":
            self.right_arc()

    def get_state(self):
        """Return a snapshot of the configuration as a dict.

        The lists are shallow copies, so a state captured at one step is not
        altered by later transitions.
        """
        return {
            "stack": list(self.stack),
            "buffer": list(self.buffer),
            "dependencies": list(self.dependencies),
        }


In [41]:
def dependency_to_actions(sentence, heads, strict=False):
    """Derive an arc-standard action sequence from gold head indices.

    Args:
        sentence: Sequence of tokens; only its length is used.
        heads: ``heads[i]`` is the 1-based index of the head of token ``i + 1``,
            with 0 meaning ROOT. Every head must lie in ``0..len(sentence)``,
            otherwise ``KeyError`` is raised.
        strict: When True, raise ``ValueError`` if the tree cannot be fully
            derived (for example a non-projective tree). When False (the
            default) the actions produced so far are returned, which is then
            an incomplete derivation.

    Returns:
        List of ``"SHIFT"``, ``"LEFT-ARC"`` and ``"RIGHT-ARC"`` strings.
    """
    token_count = len(sentence)

    # The stack holds token indices only (0 is ROOT). The buffer is implicit:
    # it is the range next_position..token_count, so shifting is O(1) instead
    # of popping from the front of a list.
    stack = [0]
    next_position = 1
    actions = []

    dependents = {index: [] for index in range(token_count + 1)}  # Include ROOT
    for dependent_index, head in enumerate(heads, start=1):
        dependents[head].append(dependent_index)

    attached = set()

    def has_all_dependents(index):
        # A token may only be attached (and leave the stack) once every one
        # of its own dependents has been attached.
        return all(child in attached for child in dependents[index])

    while next_position <= token_count or len(stack) > 1:
        if len(stack) >= 2:
            top = stack[-1]
            below = stack[-2]

            if top in dependents[below] and has_all_dependents(top):
                actions.append("RIGHT-ARC")
                attached.add(top)
                stack.pop()
                continue

            if below in dependents[top] and has_all_dependents(below):
                actions.append("LEFT-ARC")
                attached.add(below)
                del stack[-2]
                continue

        if next_position <= token_count:
            actions.append("SHIFT")
            stack.append(next_position)
            next_position += 1
        else:
            # Buffer exhausted and no arc applies: the derivation is stuck.
            if strict:
                raise ValueError(
                    "cannot derive the given heads with arc-standard transitions "
                    f"(stuck with stack {stack})"
                )
            break

    return actions

In [43]:
def test_parser():
    """Smoke-test the oracle and the parser on a three-word sentence.

    Prints the generated actions and every intermediate parser state, then
    asserts that replaying the actions consumes the whole sentence and rebuilds
    exactly the gold arcs.
    """
    sentence = ["I", "like", "cats"]
    heads = [2, 0, 2]  # 1-based indices: I->like, like->ROOT, cats->like

    print("Testing dependency_to_actions:")
    print(f"Sentence: {sentence}")
    print(f"Heads: {heads}")

    actions = dependency_to_actions(sentence, heads)
    print(f"Generated actions: {actions}")

    parser = ArcStandardParser()
    parser.initialize(sentence)
    print("\nRunning actions through parser:")
    print(f"Initial state: {parser.get_state()}")

    for action in actions:
        parser.parse_step(action)
        print(f"After {action}: {parser.get_state()}")

    # Turn the printout into a real check: the derivation must be complete
    # and must reproduce the gold (head, dependent) arcs.
    final_state = parser.get_state()
    expected_arcs = {(head, position) for position, head in enumerate(heads, start=1)}
    assert not final_state["buffer"], "buffer should be empty after the last action"
    assert len(final_state["stack"]) == 1, "only ROOT should remain on the stack"
    assert set(final_state["dependencies"]) == expected_arcs, "parser arcs differ from gold heads"

# Run test
test_parser()

Testing dependency_to_actions:
Sentence: ['I', 'like', 'cats']
Heads: [2, 0, 2]
Generated actions: ['SHIFT', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'RIGHT-ARC', 'RIGHT-ARC']

Running actions through parser:
Initial state: {'stack': [(0, 'ROOT')], 'buffer': [(1, 'I'), (2, 'like'), (3, 'cats')], 'dependencies': []}
After SHIFT: {'stack': [(0, 'ROOT'), (1, 'I')], 'buffer': [(2, 'like'), (3, 'cats')], 'dependencies': []}
After SHIFT: {'stack': [(0, 'ROOT'), (1, 'I'), (2, 'like')], 'buffer': [(3, 'cats')], 'dependencies': []}
After LEFT-ARC: {'stack': [(0, 'ROOT'), (2, 'like')], 'buffer': [(3, 'cats')], 'dependencies': [(2, 1)]}
After SHIFT: {'stack': [(0, 'ROOT'), (2, 'like'), (3, 'cats')], 'buffer': [], 'dependencies': [(2, 1)]}
After RIGHT-ARC: {'stack': [(0, 'ROOT'), (2, 'like')], 'buffer': [], 'dependencies': [(2, 1), (2, 3)]}
After RIGHT-ARC: {'stack': [(0, 'ROOT')], 'buffer': [], 'dependencies': [(2, 1), (2, 3), (0, 2)]}


In [52]:
def clean_dataset(sentences, heads):
    """Keep only sentences whose head annotation is usable.

    A sentence is kept when every head converts to ``int``, every head lies in
    ``0..len(sentence)`` and exactly one token has head 0 (ROOT). Everything
    else is skipped.

    Args:
        sentences: Sequence of token lists.
        heads: Sequence of head lists, parallel to ``sentences``; entries may
            be ints or anything ``int()`` accepts.

    Returns:
        ``(cleaned_sentences, cleaned_heads, skipped_indices)`` where
        ``cleaned_heads`` holds lists of ints and ``skipped_indices`` holds the
        positions (in the input) of the rejected sentences.
    """
    cleaned_sentences = []
    cleaned_heads = []
    skipped_indices = []

    for i, (sentence, head) in enumerate(zip(sentences, heads)):
        try:
            head_indices = [int(h) for h in head]
            token_count = len(sentence)
        except (TypeError, ValueError, OverflowError):
            # Heads that are not integer-like are skipped, not treated as fatal.
            skipped_indices.append(i)
            continue

        heads_in_range = all(0 <= h <= token_count for h in head_indices)
        has_single_root = head_indices.count(0) == 1

        if heads_in_range and has_single_root:
            cleaned_sentences.append(sentence)
            cleaned_heads.append(head_indices)
        else:
            skipped_indices.append(i)

    # Debugging aid: when nothing survived, show a few rejected inputs.
    if len(cleaned_sentences) == 0:
        print("\nExample of first few skipped sentences:")
        for idx in skipped_indices[:3]:
            print(f"\nSentence {idx}:")
            print(f"Text: {sentences[idx]}")
            print(f"Heads: {heads[idx]}")

    return cleaned_sentences, cleaned_heads, skipped_indices

In [53]:
def _extract_split(dataset, split):
    """Return ``(sentences, heads)`` for one split of ``dataset``.

    Reads the ``'tokens'`` and ``'head'`` columns of ``dataset[split]``. If the
    first row of a column is not already a ``list``, every row of that column
    is converted to a list. An empty column is returned unchanged.
    """
    split_data = dataset[split]
    sentences = split_data['tokens']
    heads = split_data['head']

    # The length guard keeps an empty split from raising IndexError on [0].
    if len(sentences) > 0 and not isinstance(sentences[0], list):
        sentences = [list(s) for s in sentences]
    if len(heads) > 0 and not isinstance(heads[0], list):
        heads = [list(h) for h in heads]
    return sentences, heads


def extract_train_data(dataset):
    """Return ``(sentences, heads)`` of the ``'train'`` split."""
    return _extract_split(dataset, 'train')


def extract_val_data(dataset):
    """Return ``(sentences, heads)`` of the ``'validation'`` split."""
    return _extract_split(dataset, 'validation')


def extract_test_data(dataset):
    """Return ``(sentences, heads)`` of the ``'test'`` split."""
    return _extract_split(dataset, 'test')

In [45]:
# Process full datasets
# Step 1: pull the raw token / head columns of every English split.
train_sentences, train_heads = extract_train_data(english_data)
val_sentences, val_heads = extract_val_data(english_data)
test_sentences, test_heads = extract_test_data(english_data)

# Step 2: drop sentences whose head annotation fails validation.
clean_train_sent, clean_train_heads, skipped_train = clean_dataset(train_sentences, train_heads)
clean_val_sent, clean_val_heads, skipped_val = clean_dataset(val_sentences, val_heads)
clean_test_sent, clean_test_heads, skipped_test = clean_dataset(test_sentences, test_heads)

# Step 3: report how many sentences survived per split.
print("\nFinal Dataset Sizes:")
print(f"Train: {len(clean_train_sent)} sentences")
print(f"Validation: {len(clean_val_sent)} sentences")
print(f"Test: {len(clean_test_sent)} sentences")


Final Dataset Sizes:
Train: 10447 sentences
Validation: 1722 sentences
Test: 1790 sentences


In [10]:
# Test with just first few examples
sample_sentences = train_sentences[:5]
sample_heads = train_heads[:5]

print("Testing cleaning with first 5 sentences:")
# Use sample-specific names here. Assigning to clean_test_sent / clean_test_heads
# would overwrite the cleaned test split, and on a top-to-bottom run test_data
# would then be built from these five training sentences.
sample_clean_sent, sample_clean_heads, sample_skipped = clean_dataset(sample_sentences, sample_heads)

for i, (sample_sent, sample_head) in enumerate(zip(sample_clean_sent, sample_clean_heads)):
    print(f"\nSentence {i}:")
    print(f"Text: {sample_sent}")
    print(f"Heads: {sample_head}")

Testing cleaning with first 5 sentences:
Processed 5 sentences
Kept 5 sentences
Skipped 0 sentences

Sentence 0:
Text: ['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.']
Heads: [0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8, 18, 18, 7, 21, 21, 18, 23, 21, 21, 28, 28, 28, 21, 1]

Sentence 1:
Text: ['[', 'This', 'killing', 'of', 'a', 'respected', 'cleric', 'will', 'be', 'causing', 'us', 'trouble', 'for', 'years', 'to', 'come', '.', ']']
Heads: [10, 3, 10, 7, 7, 7, 3, 10, 10, 0, 10, 10, 14, 10, 16, 14, 10, 10]

Sentence 2:
Text: ['DPA', ':', 'Iraqi', 'authorities', 'announced', 'that', 'they', 'had', 'busted', 'up', '3', 'terrorist', 'cells', 'operating', 'in', 'Baghdad', '.']
Heads: [0, 1, 4, 5, 1, 9, 9, 9, 5, 9, 13, 13, 9, 13, 16, 14, 1]

Sentence 3:
Text: ['Two', 'of', 'them', 'were', 'being', 'run', 'by', '2', 'off

In [46]:
def generate_training_data(sentences, heads):
    """Pair every sentence with the action sequence derived from its heads.

    Sentences for which ``dependency_to_actions`` raises are reported on
    stdout and left out of the result.

    Returns:
        List of ``(sentence, actions)`` tuples in input order.
    """
    pairs = []
    for tokens, token_heads in zip(sentences, heads):
        try:
            pairs.append((tokens, dependency_to_actions(tokens, token_heads)))
        except Exception as error:
            print(f"Error: {error}")
    return pairs


In [47]:
# Turn every cleaned split into a list of (sentence, action sequence) pairs.
print("Generating training sequences...")
train_data = generate_training_data(clean_train_sent, clean_train_heads)
val_data = generate_training_data(clean_val_sent, clean_val_heads)
test_data = generate_training_data(clean_test_sent, clean_test_heads)

# Report how many pairs each split produced.
print(f"Training sequences: {len(train_data)}")
print(f"Validation sequences: {len(val_data)}")
print(f"Test sequences: {len(test_data)}")

# Show the first training pair as a sanity check.
print("\nExample training sequence:")
example_sent, example_actions = train_data[0]
print(f"Sentence: {example_sent}")
print(f"Action sequence: {example_actions}")

Generating training sequences...
Training sequences: 10447
Validation sequences: 1722
Test sequences: 1790

Example training sequence:
Sentence: ['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.']
Action sequence: ['SHIFT', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'LEFT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC'

In [13]:
# Peek at the first five (sentence, action sequence) pairs.
print(train_data[:5])

[(['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'], ['SHIFT', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'LEFT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'RIGHT-ARC']), (['[', 'This', 'killing', 'of', 'a', 'respected', 'cleric', 'will', 'be', 'causing', 'us', 'trouble', 'for', 'years', 'to', 'come', '.', ']'],

In [14]:
def verify_action_sequence(sentence, heads, actions):
    """
    Verify that an action sequence produces the correct dependencies.

    Replays ``actions`` through a fresh ``ArcStandardParser``, rebuilds the
    head of every token from the resulting arcs, prints both head lists and
    returns whether they are equal.

    Args:
        sentence: List of tokens.
        heads: Gold heads, one per token (1-based, 0 for ROOT).
        actions: Action names to replay.

    Returns:
        True when the rebuilt heads equal ``heads``, else False.
    """
    parser = ArcStandardParser()
    parser.initialize(sentence)

    for action in actions:
        parser.parse_step(action)

    dependencies = parser.get_state()["dependencies"]

    # None marks "never attached". Starting from 0 would make a token that
    # never received an arc look like a correctly attached root, so an
    # incomplete sequence could be reported as a match.
    predicted_heads = [None] * len(sentence)
    for head_idx, dependent_idx in dependencies:
        predicted_heads[dependent_idx - 1] = head_idx

    is_match = predicted_heads == heads
    print("Original heads:", heads)
    print("Predicted heads:", predicted_heads)
    print("Match:", is_match)
    return is_match

# Test
# NOTE(review): `sentence` and `heads` are assigned here but not used by the
# check below, which takes its inputs from train_data / clean_train_heads.
sentence = ['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.']
heads = [0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8, 18, 18, 7, 21, 21, 18, 23, 21, 21, 28, 28, 28, 21, 1]

# Test with the example
example_sent, example_actions = train_data[0]
example_heads = clean_train_heads[0]

# Recompute the actions for the first training sentence and replay them
# through the parser; the cell's value is the boolean match result.
actions = dependency_to_actions(example_sent, example_heads)
verify_action_sequence(example_sent, example_heads, actions)

Original heads: [0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8, 18, 18, 7, 21, 21, 18, 23, 21, 21, 28, 28, 28, 21, 1]
Predicted heads: [0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8, 18, 18, 7, 21, 21, 18, 23, 21, 21, 28, 28, 28, 21, 1]
Match: True


True

In [15]:
# Show the action sequence computed for the example sentence above.
print(actions)

['SHIFT', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'SHIFT', 'SHIFT', 'SHIFT', 'SHIFT', 'LEFT-ARC', 'LEFT-ARC', 'LEFT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'RIGHT-ARC', 'SHIFT', 'RIGHT-ARC', 'RIGHT-ARC']


In [16]:
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

In [50]:
class DependencyParsingDataset(Dataset):
    """Torch dataset pairing tokenized sentences with integer action labels.

    For every sentence the action sequence is computed once in ``__init__``
    with ``dependency_to_actions``. ``__getitem__`` tokenizes the sentence and
    pads or truncates the action sequence to the tokenizer's output length.

    NOTE(review): the action sequence and the tokenizer output are unrelated
    in length, so position ``k`` of ``actions`` is simply the k-th action, not
    a label for the k-th token — confirm this is the intended setup.

    Args:
        sentences: Sequence of token lists.
        heads: Sequence of integer head lists, parallel to ``sentences``.
        tokenizer: Callable used as
            ``tokenizer(text, return_tensors="pt", padding=True, truncation=True)``
            returning ``"input_ids"`` and ``"attention_mask"`` tensors whose
            first dimension is 1.
        pad_action_id: Label used to fill positions beyond the end of the
            action sequence. The default 0 is also the id of ``"SHIFT"``, so
            filler positions are indistinguishable from real SHIFT labels.
            Passing -100 (the default ``ignore_index`` of
            ``nn.CrossEntropyLoss``) keeps them out of the loss; any accuracy
            computation must then skip those positions as well.
    """

    # Action name -> class id, shared by all instances.
    ACTION_TO_ID = {
        "SHIFT": 0,
        "LEFT-ARC": 1,
        "RIGHT-ARC": 2
    }

    def __init__(self, sentences, heads, tokenizer, pad_action_id=0):
        self.sentences = sentences
        self.heads = heads
        self.tokenizer = tokenizer
        self.pad_action_id = pad_action_id

        self.actions = [
            self.convert_actions_to_int(dependency_to_actions(s, h))
            for s, h in zip(sentences, heads)
        ]

    def convert_actions_to_int(self, actions):
        """Map action names to their integer ids (``KeyError`` on unknown names)."""
        return [self.ACTION_TO_ID[action] for action in actions]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``actions`` for one sentence.

        All three are 1-D tensors of the same length (the tokenizer's output
        length for this sentence).
        """
        sentence = self.sentences[idx]
        actions = self.actions[idx]

        inputs = self.tokenizer(
            " ".join(sentence),
            return_tensors="pt",
            padding=True,
            truncation=True
        )

        token_count = inputs["input_ids"].size(1)

        # Force the label sequence to the tokenizer's length. Both branches
        # build a new list, so the cached self.actions entry is not modified.
        if len(actions) < token_count:
            actions = actions + [self.pad_action_id] * (token_count - len(actions))
        elif len(actions) > token_count:
            actions = actions[:token_count]

        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "actions": torch.tensor(actions, dtype=torch.long)
        }


In [51]:
class DependencyParserModel(nn.Module):
    """BERT encoder followed by a linear layer applied at every position.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_actions: Size of the output (number of action classes).
    """

    def __init__(self, bert_model_name, num_actions):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_actions)

    def forward(self, input_ids, attention_mask):
        """Return action logits computed from the encoder's last hidden state."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.fc(encoded.last_hidden_state)


In [48]:
def train_model(model, dataloader, optimizer, criterion, num_epochs=3, device=None):
    """Train ``model`` and print the mean batch loss after every epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning logits whose last dimension is the number of actions.
        dataloader: Sized iterable of dicts with ``"input_ids"``,
            ``"attention_mask"`` and ``"actions"`` tensors.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(logits, targets)`` on the
            flattened positions whose attention mask is non-zero.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model), so the
            function does not rely on a notebook-level ``device`` variable.
    """
    if device is None:
        first_parameter = next(model.parameters(), None)
        device = first_parameter.device if first_parameter is not None else torch.device("cpu")

    # Guard the per-epoch average against an empty dataloader.
    batch_count = max(len(dataloader), 1)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            actions = batch["actions"].to(device)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)

            # Flatten to one row per position and keep only the positions
            # whose attention mask is non-zero.
            num_actions = logits.size(-1)
            valid = attention_mask.reshape(-1).bool()
            flat_logits = logits.reshape(-1, num_actions)[valid]
            flat_actions = actions.reshape(-1)[valid]

            loss = criterion(flat_logits, flat_actions)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / batch_count}")


In [49]:
def evaluate_model(model, dataloader, device=None):
    """Compute and print position-level accuracy of ``model`` on ``dataloader``.

    Only positions whose attention mask is non-zero are scored. The model is
    switched to eval mode and left there.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning logits of shape ``(batch, seq_len, num_actions)``.
        dataloader: Iterable of dicts with ``"input_ids"``,
            ``"attention_mask"`` and ``"actions"`` tensors.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model), so the
            function does not rely on a notebook-level ``device`` variable.

    Returns:
        Fraction of scored positions predicted correctly, or 0.0 when no
        position was scored (for example with an empty dataloader).
    """
    if device is None:
        first_parameter = next(model.parameters(), None)
        device = first_parameter.device if first_parameter is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            actions = batch["actions"].to(device)

            logits = model(input_ids, attention_mask)
            predictions = torch.argmax(logits, dim=2)

            valid = attention_mask.reshape(-1).bool()
            predictions = predictions.reshape(-1)[valid]
            actions = actions.reshape(-1)[valid]

            correct += (predictions == actions).sum().item()
            total += actions.size(0)

    # Avoid ZeroDivisionError when nothing was scored.
    accuracy = correct / total if total else 0.0
    print(f"Evaluation Accuracy: {accuracy:.4f}")
    return accuracy


In [21]:
# Build the training dataset and loader from the cleaned training split.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = DependencyParsingDataset(clean_train_sent, clean_train_heads, tokenizer)
# NOTE(review): batch_size=1 — items differ in length from sentence to sentence
# (see the shapes printed below); presumably larger batches would need a
# padding collate function — confirm.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [22]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [23]:
# Three output classes, one per action id produced by DependencyParsingDataset (0, 1, 2).
num_actions = 3
model = DependencyParserModel("bert-base-uncased", num_actions).to(device)
# Adam over all model parameters (encoder and classifier head) with a learning rate of 1e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

In [24]:
# Sanity check: print the tensor shapes of the first three batches.
# The loader shuffles, so the batches shown differ from run to run.
for i, batch in enumerate(dataloader):
    print(f"Batch {i + 1}:")
    print("Input IDs shape:", batch["input_ids"].shape)
    print("Attention Mask shape:", batch["attention_mask"].shape)
    print("Actions shape:", batch["actions"].shape)
    print("-" * 40)
    if i == 2:  # Limit to first 3 batches
        break


Batch 1:
Input IDs shape: torch.Size([1, 17])
Attention Mask shape: torch.Size([1, 17])
Actions shape: torch.Size([1, 17])
----------------------------------------
Batch 2:
Input IDs shape: torch.Size([1, 11])
Attention Mask shape: torch.Size([1, 11])
Actions shape: torch.Size([1, 11])
----------------------------------------
Batch 3:
Input IDs shape: torch.Size([1, 7])
Attention Mask shape: torch.Size([1, 7])
Actions shape: torch.Size([1, 7])
----------------------------------------


In [25]:
# Sanity check: print the tensor contents of the first three batches.
# The loader shuffles, so the batches shown differ from run to run.
for i, batch in enumerate(dataloader):
    print(f"Batch {i + 1}:")
    print("Input IDs:", batch["input_ids"])
    print("Attention Mask:", batch["attention_mask"])
    print("Actions:", batch["actions"])
    print("-" * 40)
    if i == 2:  # Limit to first 3 batches
        break


Batch 1:
Input IDs: tensor([[  101,  2023, 16350, 21500,  8636,  2974,  3084,  2825,  2309,  3529,
         23730,  1997, 21500,  3401, 21823,  9289,  2015,  1012,   102]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
Actions: tensor([[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 2, 0, 0, 1, 0, 1, 0]])
----------------------------------------
Batch 2:
Input IDs: tensor([[ 101, 9278, 2482, 1004, 3298, 2000, 2034, 3309, 1006, 2625, 2084, 1015,
         3178, 2185, 1007, 1025, 7859,  102]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
Actions: tensor([[0, 0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 2, 0, 0, 1]])
----------------------------------------
Batch 3:
Input IDs: tensor([[  101, 14163,  7377, 11335,  2546,  2253,  2006,  2000,  2377,  4517,
         20911, 21530,  2007,  2634,  1999,  2526,  1010, 26875,  2162,  3807,
          2008,  2095,  1012,   102]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [38]:
# Train on the training dataloader for three epochs; the mean loss is printed per epoch.
train_model(model, dataloader, optimizer, criterion, num_epochs=3)

Epoch 1/3, Loss: 0.25821108132474824
Epoch 2/3, Loss: 0.2262984521646665
Epoch 3/3, Loss: 0.20295755827890913


In [39]:
# NOTE(review): `dataloader` wraps the training split (clean_train_sent), so the
# figure reported here is training accuracy, not held-out performance. Build a
# loader over clean_val_sent / clean_val_heads to measure generalization.
evaluate_model(model, dataloader)

Evaluation Accuracy: 0.8892


0.8892249811888638