In [None]:
# NOTE(review): the package name is spelled "Enviroment" — confirm it matches the on-disk directory.
from Enviroment.BaseTaggerEnv.environment import PosCorrectionEnv

# Absolute import: a notebook kernel runs cells as __main__ with no parent
# package, so a relative import like `from ..Utils.dataset import ...` raises
# ImportError. `Algorithms.Utils` is the same package path the later cells use
# for model_training / model_inference.
from Algorithms.Utils.dataset import (
    brown_to_training_data,
    get_brown_as_universal,
    get_brown_sentences,
    get_tag_list,
)
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
)
import torch
import gymnasium as gym
from torch.utils.data import Dataset

In [None]:
# Load the Brown corpus via the Utils helper and preview its first ten entries.
dataset = get_brown_as_universal()
print(dataset[0:10])

# tag_list is the tag inventory passed to both training and inference below.
tag_list = get_tag_list()
n_tags = len(tag_list)
print(f"Tag list: {tag_list}, Number of tags: {n_tags}")

In [None]:
# --- 1. Dataset Setup ---
# Format: (Sentence, [List of Universal Tags])
N_TRAIN_SENTENCES = 1000  # size of the training subset; raise for a fuller run
MODEL_DIR = "./pos_tagger_model"  # directory train_pos_tagger saves the model to
N_EPOCHS = 3

# Slicing (instead of indexing range(N)) keeps this cell working even when the
# corpus holds fewer than N_TRAIN_SENTENCES sentences; list() materializes the
# slice in case the corpus object returns a lazy view.
training_data = brown_to_training_data(list(dataset[:N_TRAIN_SENTENCES]))

# --- 2. Train and Save Model using Utils ---
from Algorithms.Utils.model_training import train_pos_tagger

# This function handles tokenization, dataset creation, training, and saving
train_pos_tagger(training_data, tag_list, output_dir=MODEL_DIR, epochs=N_EPOCHS)

In [None]:
# --- 3. Registration ---
# Register the correction environment under a namespaced id so gym.make can
# build it. The kwargs are forwarded to the PosCorrectionEnv constructor;
# "en_core_web_sm" is presumably a spaCy pipeline name — confirm in the env.
gym.register(
    id="gymnasium_env/PosCorrection-v0",
    entry_point=PosCorrectionEnv,
    kwargs={
        "dataset": training_data,
        "nlp_model": "en_core_web_sm",
    },
)

# --- 4. Create and Test ---
ACTION_NAMES = ("KEEP", "SHIFT", "DICT")  # display labels, indexed by action id
KEEP_ACTION = 0

env = gym.make("gymnasium_env/PosCorrection-v0")
try:
    obs, info = env.reset(seed=42)

    print(f"Start Word: '{info['word']}' | Base Tag: {info['base_tag']}")
    print(f"Observation Keys: {obs.keys()}")

    done = False
    total_reward = 0

    print("\n--- Starting Episode ---")
    while not done:
        # Fixed baseline policy: always KEEP.
        # Use env.action_space.sample() instead for a random agent.
        action = KEEP_ACTION

        obs, reward, terminated, truncated, info = env.step(action)

        action_name = ACTION_NAMES[action]
        # info may lack "word" (e.g. after the last word); fall back to "END".
        word = info.get("word", "END")

        print(f"Action: {action_name: <5} | Word: {word: <10} | Reward: {reward:.1f}")

        total_reward += reward
        done = terminated or truncated

    print(f"Episode Finished. Total Reward: {total_reward}")
finally:
    # Release environment resources even if reset/step raises mid-episode.
    env.close()

In [None]:
# --- Example: Loading and Using the Saved Model with Utils ---
from Algorithms.Utils.model_inference import PosTaggerInference

# Reload the tagger from the directory the training cell wrote to, together
# with the same tag list it was trained with.
saved_model_path = "./pos_tagger_model"
inference = PosTaggerInference(saved_model_path, tag_list)

# Tag one sample sentence; predict() returns an iterable of (token, tag) pairs.
text = "Reinforcement learning is fascinating."
results = inference.predict(text)

print(f"Input: {text}")
for tok, pos in results:
    # Left-align each token in a 15-character column so the tags line up.
    print(f"{tok:<15} -> {pos}")