In [None]:
import json
import torch
from transformers import (
    AutoTokenizer,
    AutoConfig,
    T5ForConditionalGeneration,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset
from collators import DataCollatorForT5MLM

# To resume from checkpoint with PyTorch > 2.6:
# allow torch.load to unpickle numpy._core.multiarray._reconstruct and
# numpy.ndarray by registering both as safe globals.
# NOTE(review): the stricter torch.load default appears to start at 2.6 itself,
# so this likely applies to ">= 2.6" — confirm.
# NOTE(review): `numpy._core` looks like the NumPy 2.x module path — confirm
# against the installed NumPy version.
from numpy._core.multiarray import _reconstruct
from numpy import ndarray
torch.serialization.add_safe_globals([_reconstruct, ndarray])

# Load the run configuration from its JSON file.
with open("./configs/00_config.json", "r") as f:
    config_args = json.load(f)

# "dataset" is required. Without this check load_dataset() would be called with
# None and fail with an error that never mentions the missing config key.
dataset_cache_path = config_args.get("dataset")
if dataset_cache_path is None:
    raise KeyError('"dataset" is missing from ./configs/00_config.json')
max_seq_length = config_args.get("max_seq_length", 512)

# 1. Load the dataset identified by the "dataset" config entry.
print(f"Loading dataset from cache: {dataset_cache_path}")
grouped_dataset = load_dataset(dataset_cache_path)

# 2. Tokenizer initialization.
tokenizer = AutoTokenizer.from_pretrained(config_args["model_name_or_path"])
# Only fall back to EOS when the tokenizer has no pad token of its own.
# T5 tokenizers already define a dedicated pad token; overwriting it with EOS
# makes padding indistinguishable from end-of-sequence.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Number of examples in the "train" split.
print(len(grouped_dataset['train']))

In [None]:
# Walk the entire "test" split and print the running index of every example
# (one output line per example).
for i, sample in enumerate(grouped_dataset['test']):
    print(i)

In [None]:
# Number of examples in the "test" split.
print(len(grouped_dataset['test']))

In [None]:
# Show the raw text of the first training example.
# Rows are sliced before the column is selected: `['train']['text'][:1]` asks
# for the whole "text" column first and can materialize every row just to
# keep one of them.
for sample in grouped_dataset['train'][:1]['text']:
    print(sample)

In [None]:
from collators import DataCollatorForT5MLM

# 3. Build the DataCollatorForT5MLM from config values; input and target
#    lengths are both set to max_seq_length.
data_collator = DataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=config_args.get("mlm_probability", 0.15),
    mean_noise_span_length=config_args.get("mean_noise_span_length", 3.0),
    input_length=max_seq_length,
    target_length=max_seq_length,  # Adjust if needed.
)

# 4. Load the model weights.
# The checkpoint location can be overridden with a "model_checkpoint" entry in
# the config file; when absent, the previously hard-coded path is used, so
# existing configs behave exactly as before.
model_checkpoint = config_args.get(
    "model_checkpoint",
    "/home/vejvar-martin-nj/git/picard/results/t5/unbiased-openwebtext/checkpoint-45200",
)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

# 5. Set up TrainingArguments. Every value read through config_args.get() falls
#    back to the default shown when the config file does not define it.
# NOTE(review): resume_from_checkpoint is only stored on the arguments object
# here; no trainer.train(...) call is visible in this notebook — confirm it is
# forwarded wherever training is started.
# NOTE(review): with load_best_model_at_end=True, save_steps presumably has to
# be a multiple of eval_steps (true for the 400/200 defaults) — confirm for
# config overrides.
training_args = TrainingArguments(
    output_dir=config_args["output_dir"],
    do_train=config_args.get("do_train", True),
    do_eval=config_args.get("do_eval", True),
    num_train_epochs=config_args.get("num_train_epochs", 4),
    per_device_train_batch_size=config_args.get("per_device_train_batch_size", 12),
    per_device_eval_batch_size=config_args.get("per_device_eval_batch_size", 16),
    gradient_accumulation_steps=config_args.get("gradient_accumulation_steps", 4),
    learning_rate=config_args.get("learning_rate", 3e-4),
    lr_scheduler_type=config_args.get("lr_scheduler_type", "constant"),
    weight_decay=config_args.get("weight_decay", 0.01),
    logging_steps=config_args.get("logging_steps", 40),
    eval_steps=config_args.get("eval_steps", 200),
    save_steps=config_args.get("save_steps", 400),
    overwrite_output_dir=config_args.get("overwrite_output_dir", False),
    resume_from_checkpoint=config_args.get("resume_from_checkpoint", False),
    ignore_data_skip=config_args.get("ignore_data_skip", False),
    save_total_limit=config_args.get("save_total_limit", 2),
    seed=config_args.get("seed", 42),
    load_best_model_at_end=True,
    evaluation_strategy="steps",
    metric_for_best_model="loss",
    greater_is_better=False,  # lower loss is better
    fp16=True,  # Recommended if GPU supports mixed precision
    logging_dir="../logs",
    report_to=["wandb"],
)

In [None]:
# 6. Create the Trainer.
# Wires together the model, the training arguments, the "train" and "test"
# splits of grouped_dataset, the data collator and the tokenizer.
# NOTE(review): recent Transformers releases appear to deprecate `tokenizer=`
# in favour of `processing_class=` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=grouped_dataset["train"],
    eval_dataset=grouped_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
from torch.utils.data import DataLoader

# Qualitative check: collate a few training examples, run beam-search generation
# on the first batch and print input / target / prediction side by side.

# Select a small sample from the training dataset
sample_dataset = grouped_dataset["train"].select(range(20))

# Create a DataLoader with the data collator
sample_loader = DataLoader(sample_dataset, batch_size=8, collate_fn=data_collator)

for batch in sample_loader:
    # Move batch to the appropriate device
    batch = {k: v.to(model.device) for k, v in batch.items()}

    # Generate predictions
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=128,
            num_beams=4
        )

    # The loss-ignore index (-100) is not a valid token id and makes decoding
    # fail. If the collator emits it, swap it for the pad id first; labels
    # without -100 pass through unchanged.
    label_ids = batch["labels"]
    label_ids = label_ids.masked_fill(label_ids == -100, tokenizer.pad_token_id)

    # Decode inputs, predictions, and labels (special tokens are kept)
    decoded_inputs = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
    decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=False)

    # Display the results
    for inp, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
        print(f"Input:\n{inp}\n\nTarget:\n{label}\n\nPrediction:\n{pred}\n{'-'*40}")
    break  # Remove this break to process all batches

In [1]:
import json
import torch
from transformers import (
    AutoTokenizer,
)
from datasets import load_from_disk

# To resume from checkpoint with PyTorch > 2.6:
# allow torch.load to unpickle numpy._core.multiarray._reconstruct and
# numpy.ndarray by registering both as safe globals.
# NOTE(review): the stricter torch.load default appears to start at 2.6 itself,
# so this likely applies to ">= 2.6" — confirm.
# NOTE(review): `numpy._core` looks like the NumPy 2.x module path — confirm
# against the installed NumPy version.
from numpy._core.multiarray import _reconstruct
from numpy import ndarray
torch.serialization.add_safe_globals([_reconstruct, ndarray])

# dataset_cache_path = config_args.get("dataset_cache_path", "/work/datasets/Skylion007___openwebtext")

# Load the run configuration from its JSON file.
with open("./configs/00_config.json", "r") as f:
    config_args = json.load(f)

# Both values fall back to the defaults shown when the config omits them.
dataset_cache_path = config_args.get("dataset_cache_path", "/work/datasets/owt-10k-clean")
max_seq_length = config_args.get("max_seq_length", 512)

# 1. Load preprocessed dataset from disk
print(f"Loading dataset from cache: {dataset_cache_path}")
grouped_dataset = load_from_disk(dataset_cache_path)

# 2. Tokenizer initialization.
tokenizer = AutoTokenizer.from_pretrained(config_args["model_name_or_path"])
# Only fall back to EOS when the tokenizer has no pad token of its own.
# T5 tokenizers already define a dedicated pad token; overwriting it with EOS
# makes padding indistinguishable from end-of-sequence.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading dataset from cache: /work/datasets/openwebtext-clean


Loading dataset from disk:   0%|          | 0/45 [00:00<?, ?it/s]

In [2]:
# Report the size of the "train" split, then the character length of the first
# "test" example after decoding its token ids (special tokens kept).
print(len(grouped_dataset['train']))
for ex in grouped_dataset['test']:
    decoded_text = tokenizer.decode(ex['input_ids'], skip_special_tokens=False)
    print(len(decoded_text))
    break  # only the first example is inspected

8684583
2383


In [None]:
import re
from data_utils import filter_sql_queries, QL_PATTERNS

# One case-insensitive regex that matches if any entry of QL_PATTERNS matches
# (the entries are joined as alternatives).
pattern = re.compile(pattern="|".join(QL_PATTERNS), flags=re.IGNORECASE)
# pattern = re.compile(
#     r"\bSELECT\b.*\bFROM\b|\bMATCH\b.*\bRETURN\b|\bselect\b.*\bwhere\b",
#     re.IGNORECASE,
# )

def contains_query_language(example_text):
    """Return True when `example_text` matches the query-language regex.

    The previous implementation negated the match, so it returned True for
    texts *without* a match. That contradicts the function name and made any
    "examples with query language" count built on it count the opposite set.

    Args:
        example_text: Text to search with the module-level `pattern`.

    Returns:
        bool: True if `pattern` matches anywhere in the text, else False.
    """
    return pattern.search(example_text) is not None

# Decode every example of the chosen split and keep the texts for which
# contains_query_language() returns True, then report how many were kept.
split = 'test'
decoded_texts = (tokenizer.decode(ex['input_ids']) for ex in grouped_dataset[split])
remaining = [text for text in decoded_texts if contains_query_language(text)]

print(f"Remaining examples with query language mentions: {len(remaining)}")

In [None]:
# Re-print the size of `remaining` (same message as the cell that builds it).
print(f"Remaining examples with query language mentions: {len(remaining)}")