In [1]:
# Imports and Setup
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import logging
import json
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the project root (the parent of the notebook's working directory) importable
# so the `src.*` imports below resolve. Guarded so that re-running this cell does not
# keep appending duplicate entries to sys.path.
PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

# Import custom modules
from src.models.bert_model import TrollDetector
from src.models.trainer import TrollDetectorTrainer
from src.data_tools.dataset import TrollDataset, collate_batch

# Configure the root logger so INFO-level records from project modules
# (e.g. the dataset-construction messages) show up in the notebook output.
# Note: logging.basicConfig does nothing if the root logger already has handlers.
log_level = logging.INFO
logging.basicConfig(level=log_level)
logger = logging.getLogger(name=__name__)

In [2]:
# Define paths
DATA_DIR = Path('data')
PROCESSED_DATA_DIR = DATA_DIR / 'processed'
CHECKPOINT_DIR = Path('./checkpoints')

# Create checkpoint directory (no-op if it already exists)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Training configuration
config = {
    'model_name': 'distilbert-base-multilingual-cased',
    # Alternative model: 'ufal/robeczech-base'
    'adapter_path': None,  # Don't use an adapter for the first training run
    'max_length': 96,
    'batch_size': 8,
    'learning_rate': 2e-5,
    'weight_decay': 0.03,
    'num_epochs': 3,
    'dropout_rate': 0.1,
    'warmup_steps': 50,
    'max_grad_norm': 1.0,
    'comments_per_user': 10,
    # NOTE(review): patience equals num_epochs, so depending on how the trainer
    # counts non-improving epochs, early stopping may never trigger — confirm.
    'early_stopping_patience': 3,
    'random_state': 17,  # fallback seed if preprocessing does not record one
}


def load_random_state(config_path: Path, default: int) -> int:
    """Return the random seed recorded in the preprocessing config file.

    Args:
        config_path: Path to the preprocessing JSON config
            (presumably written by the preprocessing step).
        default: Seed to return when the file is missing or has no
            'random_state' entry.

    Returns:
        The file's 'random_state' value, otherwise ``default``.
    """
    try:
        with open(config_path, 'r') as f:
            preproc_config = json.load(f)
    except FileNotFoundError:
        print("Warning: preprocessing_config.json not found, using default random_state")
        return default
    # Fall back to the caller's default (not a separate hard-coded value) when the key is absent.
    return preproc_config.get('random_state', default)


# Prefer the preprocessing seed so training uses the same seed as preprocessing.
config['random_state'] = load_random_state(
    PROCESSED_DATA_DIR / 'preprocessing_config.json', config['random_state']
)

# Seed the global numpy and torch RNGs (torch's global RNG drives e.g. DataLoader
# shuffling when no explicit generator is given) so runs are reproducible.
np.random.seed(config['random_state'])
torch.manual_seed(config['random_state'])

print("Configuration loaded:")
for key, value in config.items():
    print(f"{key}: {value}")

Configuration loaded:
model_name: distilbert-base-multilingual-cased
adapter_path: None
max_length: 96
batch_size: 8
learning_rate: 2e-05
weight_decay: 0.03
num_epochs: 3
dropout_rate: 0.1
warmup_steps: 50
max_grad_norm: 1.0
comments_per_user: 10
early_stopping_patience: 3
random_state: 42


In [3]:
# Load the preprocessed train/val/test splits (one parquet file per split).
SPLIT_NAMES = ('train', 'val', 'test')
split_frames = {
    name: pd.read_parquet(PROCESSED_DATA_DIR / f'{name}.parquet')
    for name in SPLIT_NAMES
}

# Fail fast if a split lacks a column used below ('author') or by the
# TrollDataset construction in the next cell ('troll' label column).
REQUIRED_COLUMNS = {'author', 'troll'}
for name, frame in split_frames.items():
    missing = REQUIRED_COLUMNS - set(frame.columns)
    assert not missing, f"{name}.parquet is missing columns: {sorted(missing)}"

train_df, val_df, test_df = (split_frames[name] for name in SPLIT_NAMES)

print("Dataset sizes:")
for name, frame in split_frames.items():
    label = f"{name.capitalize()}:"
    # Left-pad labels to a common width so the counts line up.
    print(f"{label:<6} {len(frame)} samples, {frame['author'].nunique()} authors")

Dataset sizes:
Train: 227740 samples, 8953 authors
Val:   57083 samples, 1919 authors
Test:  62796 samples, 1919 authors


In [4]:
# Build one TrollDataset per split with identical tokenization/sampling settings.
# Labels come from the 'troll' column; normalize_labels=True asks TrollDataset to
# scale them to [0, 1] (its log reports when they already are in that range).
# NOTE(review): if TrollDataset normalizes each split with its own statistics,
# val/test would be scaled differently from train — confirm it uses fixed bounds.
def make_troll_dataset(df: pd.DataFrame) -> TrollDataset:
    """Wrap one data split in a TrollDataset using the shared notebook config."""
    return TrollDataset(
        df,
        tokenizer_name=config['model_name'],
        max_length=config['max_length'],
        comments_per_user=config['comments_per_user'],
        label_column='troll',
        normalize_labels=True,
    )


def make_loader(dataset, shuffle: bool = False) -> DataLoader:
    """Batch a dataset with collate_batch using config['batch_size'].

    Args:
        dataset: The dataset to iterate over.
        shuffle: Whether to reshuffle every epoch (training only).

    Returns:
        A DataLoader over ``dataset``.
    """
    generator = None
    if shuffle:
        # Dedicated, seeded generator: the shuffle order depends only on
        # config['random_state'], not on what consumed the global torch RNG earlier.
        generator = torch.Generator().manual_seed(config['random_state'])
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        collate_fn=collate_batch,
        generator=generator,
    )


train_dataset = make_troll_dataset(train_df)
val_dataset = make_troll_dataset(val_df)
test_dataset = make_troll_dataset(test_df)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

INFO:src.data_tools.dataset:Using 'text' as text column
INFO:src.data_tools.dataset:Labels are already normalized between 0 and 1
INFO:src.data_tools.dataset:Created 20362 samples from 8953 authors
INFO:src.data_tools.dataset:Using 'text' as text column
INFO:src.data_tools.dataset:Labels are already normalized between 0 and 1
INFO:src.data_tools.dataset:Created 4274 samples from 1919 authors
INFO:src.data_tools.dataset:Using 'text' as text column
INFO:src.data_tools.dataset:Labels are already normalized between 0 and 1
INFO:src.data_tools.dataset:Created 4468 samples from 1919 authors
