In [12]:
import torch
import random
import numpy as np
from transformers import T5Tokenizer
from data.preprocessing import DataPreprocessor
from data.dataset import TextHumanizerDataset
from model.model import TextHumanizerModel
from model.trainer import Trainer
from config import Config
import os

In [3]:
def set_seed(seed):
    """Seed the random number generators used in this notebook.

    The same value is handed to Python's ``random`` module, NumPy's global
    generator and ``torch.manual_seed``. When CUDA is available, the CUDA
    generators are seeded too via ``torch.cuda.manual_seed_all``.

    Args:
        seed: Seed value passed unchanged to each library.
    """
    # The three library-level seeders take the same single argument.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators are only touched when a CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [4]:
# Build the project configuration, then seed the RNGs from its RANDOM_SEED
# value before any later cell draws random numbers.
config = Config()
set_seed(seed=config.RANDOM_SEED)

In [5]:
# Make sure the directory configured as MODEL_SAVE_PATH exists.
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test. If the path
# exists but is not a directory, this raises FileExistsError here rather than
# letting the problem surface later.
os.makedirs(config.MODEL_SAVE_PATH, exist_ok=True)

In [7]:
# Load the T5 tokenizer identified by config.MODEL_NAME; the same tokenizer
# object is passed to the dataset wrappers and the trainer in later cells.
tokenizer = T5Tokenizer.from_pretrained(config.MODEL_NAME)

In [8]:
# Build the dataset splits through the project's DataPreprocessor.
# `datasets` is indexed by 'train' and 'validation' in later cells.
print("Preparing dataset...")
preprocessor = DataPreprocessor()
datasets = preprocessor.prepare_dataset()

Preparing dataset...
Dataset loaded from Hugging Face
Train: 16706, Validation: 3580, Test: 3581


In [18]:
# Display the training split as a pandas DataFrame. This converts the whole
# split; the notebook only shows a truncated view (head and tail rows).
# NOTE(review): append .head() if only the first five rows are wanted.
datasets['train'].to_pandas()

Unnamed: 0,question,ai_text,human_text
0,How do companies go public ? And more specific...,Sure! A company can go public by selling share...,The biggest driver behind going public is the ...
1,How do I choose 401k investment funds?,Choosing the right investment funds for your 4...,I disagree strongly with chasing expenses. Don...
2,What are my risks of early assignment?,Early assignment refers to the process of bein...,"The put vs call assignment risk, is actually t..."
3,Why does running a strong enough magnet over a...,Hard drives store information on spinning disk...,The way a hard disk drive works is that it has...
4,How does the government track people downloadi...,The government can track people who download i...,I am not aware of any evidence of any western ...
...,...,...,...
16701,Why congressmen all over the world refer to ea...,In legislative bodies like Congress or Parliam...,The legislators are not directly addressing an...
16702,The concept of abstraction in computer science...,Abstraction in computer science is a way of si...,Abstraction is the idea that the user of some ...
16703,How can trading in General Motors stock be sus...,Trading in General Motors (GM) stock can be su...,The exchange always briefly suspends trading i...
16704,"Why are some foods "" Breakfast Foods "" and oth...",Some foods are traditionally eaten at certain ...,If you travel to parts of Asia you will find t...


In [19]:
# Wrap the train and validation splits in TextHumanizerDataset, in that order,
# sharing the one tokenizer between them.
train_dataset, val_dataset = (
    TextHumanizerDataset(datasets[split_name], tokenizer)
    for split_name in ('train', 'validation')
)

In [20]:
# Initialize the model: TextHumanizerModel is constructed from the shared
# config object; the print is just a progress marker for the cell.
print("Initializing model...")
model = TextHumanizerModel(config)

Initializing model...


In [22]:
# Initialize the trainer with everything built above: the model, the tokenizer,
# the train/validation dataset wrappers and the config object.
trainer = Trainer(model, tokenizer, train_dataset, val_dataset, config)

In [None]:
# Run this cell to start training. It requires the `trainer` built in the
# previous cell. The completion message is printed only after
# trainer.train() has returned.
print("Starting training...")
trainer.train()

print("Training complete!")