In [4]:
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from datasets import load_dataset, load_from_disk, Dataset
import torch
from torch.utils.data import DataLoader
import numpy as np
import pickle
import yaml
import json
import pandas as pd
from collections import defaultdict

# OmegaConf for configuration management (Hydra-style)
try:
    from omegaconf import OmegaConf
    print("✅ OmegaConf imported successfully")
except ImportError:
    print("❌ OmegaConf not found. Install with: pip install omegaconf")
    raise

root_path = Path.cwd().parent
print("📁 Root path:", root_path)
sys.path.append(str(root_path))

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger

from src.model_meta.dataset import CSVDataModule, custom_collate_fn
from src.model_meta.train import LitTransformer

print("🚀 All imports successful!")

✅ OmegaConf imported successfully
📁 Root path: /Users/takeruito/work/PrfSR
🚀 All imports successful!


# Load Config path

In [5]:
# Load config using OmegaConf (Hydra-style)
from omegaconf import OmegaConf

train_config_path = root_path / "src/model_meta/training_config.yaml"
assert Path(train_config_path).exists(), FileNotFoundError(f"Train config file not found: {train_config_path}")

# Load with OmegaConf for dot notation access
config = OmegaConf.load(train_config_path)

print(f"🔧 Configuration loaded with OmegaConf:")
print(f"Metadata path: {config.metadata_path}")
print(f"Data path: {config.data_path}")
print(f"max_epoch: {config.max_epoch}")
print(f"max_value: {config.max_value}")
print(f"min_n_tokens_in_batch: {config.min_n_tokens_in_batch}")
print(f"test_ratio: {config.test_ratio}")
print(f"val_ratio: {config.val_ratio}")
print(f"num_workers: {config.num_workers}")
print(f"token_embed_dim: {config.token_embed_dim}")
print(f"emb_expansion_factor: {config.emb_expansion_factor}")
print(f"learning_rate: {config.learning_rate}")
print(f"batch_size: {config.batch_size}")
print(f"batching_strategy: {config.batching_strategy}")

print(f"\n🤖 Transformer config:")
print(f"nhead: {config.transformer.nhead}")
print(f"num_encoder_layers: {config.transformer.num_encoder_layers}")
print(f"num_decoder_layers: {config.transformer.num_decoder_layers}")
print(f"dim_feedforward: {config.transformer.dim_feedforward}")
print(f"dropout: {config.transformer.dropout}")


🔧 Configuration loaded with OmegaConf:
Metadata path: /Users/takeruito/work/PrfSR/data/training/superfib_r1_metadata.yaml
Data path: /Users/takeruito/work/PrfSR/data/training/superfib_r1_dataset.csv
max_epoch: 1000
max_value: 2000
min_n_tokens_in_batch: 2000
test_ratio: 0.1
val_ratio: 0.25
num_workers: 13
token_embed_dim: 16
emb_expansion_factor: 1
learning_rate: 3*10**(-4)
batch_size: 64
batching_strategy: length_aware_token

🤖 Transformer config:
nhead: 16
num_encoder_layers: 4
num_decoder_layers: 6
dim_feedforward: 512
dropout: 0.1


# Load metadata

In [6]:
from omegaconf import OmegaConf
import yaml

# Load metadata using OmegaConf
metadata_path = config.metadata_path
metadata = OmegaConf.load(metadata_path)

print(f"📊 Metadata loaded:")
print(f"max_src_points: {metadata.max_src_points}")
print(f"max_tgt_length: {metadata.max_tgt_length}")
print(f"max_point_dim: {metadata.max_point_dim}")
print(f"src_vocab_size: {len(metadata.src_vocab_list)}")
print(f"tgt_vocab_size: {len(metadata.tgt_vocab_list)}")

# Create inverse vocabularies more efficiently
src_inv_vocab = {token: idx for idx, token in enumerate(metadata.src_vocab_list)}
tgt_inv_vocab = {token: idx for idx, token in enumerate(metadata.tgt_vocab_list)}

print(f"\n🔤 Source vocabulary (first 10): {list(src_inv_vocab.keys())[:10]}")
print(f"🎯 Target vocabulary (first 10): {list(tgt_inv_vocab.keys())[:10]}")

# Add inverse vocabularies to metadata
metadata.src_inv_vocab = src_inv_vocab
metadata.tgt_inv_vocab = tgt_inv_vocab

print(f"\n✅ Inverse vocabularies created:")
print(f"src '[PAD]' index: {metadata.src_inv_vocab['[PAD]']}")
print(f"tgt '[PAD]' index: {metadata.tgt_inv_vocab['[PAD]']}")
print(f"tgt '[BOS]' index: {metadata.tgt_inv_vocab['[BOS]']}")
print(f"tgt '[EOS]' index: {metadata.tgt_inv_vocab['[EOS]']}")

📊 Metadata loaded:
max_src_points: 80
max_tgt_length: 859
max_point_dim: 4
src_vocab_size: 1002
tgt_vocab_size: 15

🔤 Source vocabulary (first 10): ['[PAD]', '0', '1', '10', '100', '1000', '101', '102', '103', '104']
🎯 Target vocabulary (first 10): ['[PAD]', '[BOS]', '[EOS]', '(', ')', ',', '1', '2', '3', '4']

✅ Inverse vocabularies created:
src '[PAD]' index: 0
tgt '[PAD]' index: 0
tgt '[BOS]' index: 1
tgt '[EOS]' index: 2


# Load CSV by chunk

In [None]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=3,
    mode='min',
    dirpath='./checkpoints',
    filename='model-{epoch:02d}-{val_loss:.2f}'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    verbose=True
)

# Initialize variables
dataset_path = config.data_path
total_samples = 0
losses = []

print(f"📁 Dataset path: {dataset_path}")
print(f"🧠 Model parameters:")
print(f"  - src_token_num: {metadata.max_src_points}")
print(f"  - tgt_token_num: {len(metadata.tgt_vocab_list)}")
print(f"  - token_embed_dim: {config.token_embed_dim}")
print(f"  - max_src_dim: {metadata.max_point_dim}")
print(f"  - max_tgt_dim: {metadata.max_tgt_length}")

chunk_reader = pd.read_csv(dataset_path, chunksize=10000)
model = LitTransformer(
    src_token_num=metadata.max_src_points,
    tgt_token_num=len(metadata.tgt_vocab_list),
    token_embed_dim=config.token_embed_dim,
    max_src_dim=metadata.max_point_dim,
    max_tgt_dim=metadata.max_tgt_length,
    src_padding_idx=metadata.src_inv_vocab['[PAD]'],
    tgt_padding_idx=metadata.tgt_inv_vocab['[PAD]'],
    emb_expansion_factor=config.emb_expansion_factor,
    t_config=config.transformer,  # Now this will work with dot notation in train.py
    learning_rate=eval(config.learning_rate),
)

trainer = pl.Trainer(
    max_epochs=config.max_epoch,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback, early_stopping],
    logger=False,
    enable_progress_bar=True
)

print(f"\n🚀 Starting chunked training...")

for chunk_idx, chunk_df in enumerate(chunk_reader):
    print(f"\n📦 Chunk {chunk_idx+1} - Rows: {len(chunk_df):,}")
    
    # Convert DataFrame to Dataset
    dataset = Dataset.from_pandas(chunk_df)
    dataset = dataset.map(
        lambda x: {
            "source": eval(x["source"]),
            "target": eval(x["target"]),
        },
        batched=False,
        num_proc=1
    )
    
    # Create DataModule
    datamodule = CSVDataModule(
        dataset=dataset,
        batch_size=config.batch_size,
        num_workers=0,  # config.num_workers,
        train_val_split=1 - config.test_ratio,
        seed=42,
        collate_fn=custom_collate_fn,
        batching_strategy=config.batching_strategy,
        min_tokens_per_batch=config.min_n_tokens_in_batch,
        max_batch_size=config.batch_size,
    )
    
    # Training
    print(f"🔥 Training on chunk {chunk_idx+1}...")
    trainer.fit(model, datamodule)
    
    # Log results
    if hasattr(trainer, 'logged_metrics') and trainer.logged_metrics:
        val_loss = trainer.logged_metrics.get('val_loss', 0)
    elif hasattr(trainer, 'callback_metrics') and trainer.callback_metrics:
        val_loss = trainer.callback_metrics.get('val_loss', 0)
    else:
        val_loss = 0.0
    
    total_samples += len(chunk_df)
    losses.append(float(val_loss))
    
    print(f"✅ Chunk {chunk_idx+1} completed:")
    print(f"  - Val Loss: {val_loss:.4f}")
    print(f"  - Total samples processed: {total_samples:,}")
    
    # Break after 3 chunks for testing
    if chunk_idx >= 2:
        print("🛑 Stopping after 3 chunks for testing...")
        break

print(f"\n🎉 Training completed!")
print(f"📊 Total samples processed: {total_samples:,}")
print(f"📈 Loss progression: {[f'{loss:.4f}' for loss in losses]}")


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


📁 Dataset path: /Users/takeruito/work/PrfSR/data/training/superfib_r1_dataset.csv
🧠 Model parameters:
  - src_token_num: 80
  - tgt_token_num: 859
  - token_embed_dim: 16
  - max_src_dim: 4
  - max_tgt_dim: 15

🚀 Starting chunked training...

📦 Chunk 1 - Rows: 10,000


Map: 100%|██████████| 10000/10000 [00:01<00:00, 7945.24 examples/s]

  | Name          | Type               | Params
-----------------------------------------------------
0 | src_embedding | Embedding          | 1.3 K 
1 | fc1           | Linear             | 4.2 K 
2 | fc2           | Linear             | 4.2 K 
3 | tgt_embedding | Embedding          | 55.0 K
4 | pos_enc       | PositionalEncoding | 0     
5 | transformer   | Transformer        | 930 K 
6 | fc_out        | Linear             | 55.8 K
7 | loss_fn       | CrossEntropyLoss   | 0     
-----------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.205     Total estimated model params size (MB)


🔥 Training on chunk 1...
Train dataset: 9000 samples
Validation dataset: 1000 samples
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


RuntimeError: The size of tensor a (138) must match the size of tensor b (15) at non-singleton dimension 0