In [1]:
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from datasets import load_dataset, load_from_disk, Dataset
import torch
from torch.utils.data import DataLoader
import numpy as np
import pickle
import yaml
import json
import pandas as pd
from collections import defaultdict

# OmegaConf for configuration management (Hydra-style)
try:
    from omegaconf import OmegaConf
    print("✅ OmegaConf imported successfully")
except ImportError:
    print("❌ OmegaConf not found. Install with: pip install omegaconf")
    raise

root_path = Path.cwd().parent.parent
print("📁 Root path:", root_path)
sys.path.append(str(root_path))

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger

from src.models.dataset import CSVDataModule, custom_collate_fn
from src.models.models import LitTransformer

print("🚀 All imports successful!")

  from .autonotebook import tqdm as notebook_tqdm


✅ OmegaConf imported successfully
📁 Root path: /Users/takeruito/work/PrfSR
🚀 All imports successful!


# Load Config path

In [2]:
# Load config using OmegaConf (Hydra-style)
from omegaconf import OmegaConf

train_config_path = root_path / "src/models/training_config.yaml"
assert Path(train_config_path).exists(), FileNotFoundError(f"Train config file not found: {train_config_path}")

# Load with OmegaConf for dot notation access
config = OmegaConf.load(train_config_path)

print(f"🔧 Configuration loaded with OmegaConf:")
print(f"Metadata path: {config.metadata_path}")
print(f"Data path: {config.data_path}")
print(f"max_epoch: {config.max_epoch}")
print(f"max_value: {config.max_value}")
print(f"min_n_tokens_in_batch: {config.min_n_tokens_in_batch}")
print(f"test_ratio: {config.test_ratio}")
print(f"val_ratio: {config.val_ratio}")
print(f"num_workers: {config.num_workers}")
print(f"token_embed_dim: {config.token_embed_dim}")
print(f"emb_expansion_factor: {config.emb_expansion_factor}")
print(f"learning_rate: {config.learning_rate}")
print(f"batch_size: {config.batch_size}")
print(f"batching_strategy: {config.batching_strategy}")

print(f"\n🤖 Transformer config:")
print(f"nhead: {config.transformer.nhead}")
print(f"num_encoder_layers: {config.transformer.num_encoder_layers}")
print(f"num_decoder_layers: {config.transformer.num_decoder_layers}")
print(f"dim_feedforward: {config.transformer.dim_feedforward}")
print(f"dropout: {config.transformer.dropout}")


🔧 Configuration loaded with OmegaConf:
Metadata path: /Users/takeruito/work/PrfSR/data/training/superfib_r1_metadata.yaml
Data path: /Users/takeruito/work/PrfSR/data/training/superfib_r1_dataset.csv
max_epoch: 1
max_value: 2000
min_n_tokens_in_batch: 200
test_ratio: 0.01
val_ratio: 0.25
num_workers: 13
token_embed_dim: 16
emb_expansion_factor: 1
learning_rate: 3*10**(-4)
batch_size: 2048
batching_strategy: length_aware_token

🤖 Transformer config:
nhead: 16
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 512
dropout: 0.1


# Load metadata

In [3]:
from omegaconf import OmegaConf
import yaml

# Load metadata using OmegaConf
metadata_path = config.metadata_path
metadata = OmegaConf.load(metadata_path)

print(f"📊 Metadata loaded:")
print(f"max_src_points: {metadata.max_src_points}")
print(f"max_tgt_length: {metadata.max_tgt_length}")
print(f"max_point_dim: {metadata.max_point_dim}")
print(f"src_vocab_size: {len(metadata.src_vocab_list)}")
print(f"tgt_vocab_size: {len(metadata.tgt_vocab_list)}")

# Create inverse vocabularies more efficiently
src_inv_vocab = {token: idx for idx, token in enumerate(metadata.src_vocab_list)}
tgt_inv_vocab = {token: idx for idx, token in enumerate(metadata.tgt_vocab_list)}

print(f"\n🔤 Source vocabulary (first 10): {list(src_inv_vocab.keys())[:10]}")
print(f"🎯 Target vocabulary (first 10): {list(tgt_inv_vocab.keys())[:10]}")

# Add inverse vocabularies to metadata
metadata.src_inv_vocab = src_inv_vocab
metadata.tgt_inv_vocab = tgt_inv_vocab

print(f"\n✅ Inverse vocabularies created:")
print(f"src '[PAD]' index: {metadata.src_inv_vocab['[PAD]']}")
print(f"tgt '[PAD]' index: {metadata.tgt_inv_vocab['[PAD]']}")
print(f"tgt '[BOS]' index: {metadata.tgt_inv_vocab['[BOS]']}")
print(f"tgt '[EOS]' index: {metadata.tgt_inv_vocab['[EOS]']}")

📊 Metadata loaded:
max_src_points: 80
max_tgt_length: 859
max_point_dim: 4
src_vocab_size: 1002
tgt_vocab_size: 15

🔤 Source vocabulary (first 10): ['[PAD]', '0', '1', '10', '100', '1000', '101', '102', '103', '104']
🎯 Target vocabulary (first 10): ['[PAD]', '[BOS]', '[EOS]', '(', ')', ',', '1', '2', '3', '4']

✅ Inverse vocabularies created:
src '[PAD]' index: 0
tgt '[PAD]' index: 0
tgt '[BOS]' index: 1
tgt '[EOS]' index: 2


# Load CSV by chunk

In [4]:
# Initialize WandB
import wandb

# WandB project configuration
wandb_config = {
    "project": "prfsr-chunked-training",
    "name": f"chunk_training_{config.batching_strategy}",
    "config": {
        "batch_size": config.batch_size,
        "learning_rate": eval(config.learning_rate),
        "max_epoch": config.max_epoch,
        "token_embed_dim": config.token_embed_dim,
        "emb_expansion_factor": config.emb_expansion_factor,
        "batching_strategy": config.batching_strategy,
        "min_tokens_per_batch": config.min_n_tokens_in_batch,
        "transformer_nhead": config.transformer.nhead,
        "transformer_num_encoder_layers": config.transformer.num_encoder_layers,
        "transformer_num_decoder_layers": config.transformer.num_decoder_layers,
        "transformer_dim_feedforward": config.transformer.dim_feedforward,
        "transformer_dropout": config.transformer.dropout,
        "max_src_points": metadata.max_src_points,
        "max_tgt_length": metadata.max_tgt_length,
        "src_vocab_size": len(metadata.src_vocab_list),
        "tgt_vocab_size": len(metadata.tgt_vocab_list),
    }
}

# Initialize WandB run
wandb.init(**wandb_config)

print("🔥 WandB initialized successfully!")
print(f"📊 Project: {wandb_config['project']}")
print(f"🏷️ Run name: {wandb_config['name']}")

[34m[1mwandb[0m: Currently logged in as: [33mtakeit[0m ([33mtakeit-Keio University Global Page[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


🔥 WandB initialized successfully!
📊 Project: prfsr-chunked-training
🏷️ Run name: chunk_training_length_aware_token


In [None]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=3,
    mode='min',
    dirpath=root_path / "logs" / "checkpoints",
    filename='model-{epoch:02d}-{val_loss:.2f}'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    verbose=True
)

# Initialize WandB logger
wandb_logger = WandbLogger(
    project="prfsr-chunked-training",
    name=f"chunk_training_{config.batching_strategy}",
    log_model=True,
    save_dir= root_path / "logs",
)


# Initialize variables
dataset_path = config.data_path
total_samples = 0
losses = []
chunk_metrics = []

print(f"📁 Dataset path: {dataset_path}")
print(f"🧠 Model parameters:")
print(f"  - src_token_num: {metadata.max_src_points}")
print(f"  - tgt_token_num: {len(metadata.tgt_vocab_list)}")
print(f"  - token_embed_dim: {config.token_embed_dim}")
print(f"  - max_src_dim: {metadata.max_point_dim}")
print(f"  - max_tgt_dim: {metadata.max_tgt_length}")

chunk_reader = pd.read_csv(dataset_path, chunksize=100000)
model = LitTransformer(
    src_token_num=metadata.max_src_points,
    tgt_token_num=len(metadata.tgt_vocab_list),
    token_embed_dim=config.token_embed_dim,
    max_src_dim=metadata.max_point_dim,
    max_tgt_dim=metadata.max_tgt_length,
    src_padding_idx=metadata.src_inv_vocab['[PAD]'],
    tgt_padding_idx=metadata.tgt_inv_vocab['[PAD]'],
    emb_expansion_factor=config.emb_expansion_factor,
    t_config=config.transformer,  # Now this will work with dot notation in train.py
    learning_rate=eval(config.learning_rate),
)

trainer = pl.Trainer(
    log_every_n_steps=10,
    max_epochs=config.max_epoch,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback, early_stopping],
    logger=wandb_logger,  # Add WandB logger
    enable_progress_bar=True,
    val_check_interval=200,
)

print(f"\n🚀 Starting chunked training...")

for chunk_idx, chunk_df in enumerate(chunk_reader):
    print(f"\n📦 Chunk {chunk_idx+1} - Rows: {len(chunk_df):,}")
    
    # Convert DataFrame to Dataset
    dataset = Dataset.from_pandas(chunk_df)
    dataset = dataset.map(
        lambda x: {
            "source": eval(x["source"]),
            "target": eval(x["target"]),
        },
        batched=False,
        num_proc=1
    )
    dataset = dataset.map(lambda x: {"num_points": len(x["source"])}, batched=False, num_proc=1)
    dataset.sort("num_points", reverse=True)  # Sort by num_points in descending order
    print(dataset[:10])
    
    # Create DataModule
    datamodule = CSVDataModule(
        dataset=dataset,
        batch_size=config.batch_size,
        num_workers=0,  # config.num_workers,
        train_val_split=1 - config.test_ratio,
        seed=42,
        collate_fn=custom_collate_fn,
        batching_strategy=config.batching_strategy,
        min_tokens_per_batch=config.min_n_tokens_in_batch,
        max_batch_size=config.batch_size,
    )
     
    # show batch size
    datamodule.setup()
    for i, batch in enumerate(datamodule.train_dataloader()):
        print("source shape: ", batch[0].shape)
        print("target shape: ", batch[1].shape)
        if i >= 10:  # Only print the first batch
            break
        print()
    
    
    # Training
    print(f"🔥 Training on chunk {chunk_idx+1}...")
    trainer.fit(model, datamodule)
    
    # Log results
    if hasattr(trainer, 'logged_metrics') and trainer.logged_metrics:
        val_loss = trainer.logged_metrics.get('val_loss', 0)
        train_loss = trainer.logged_metrics.get('train_loss', 0)
    elif hasattr(trainer, 'callback_metrics') and trainer.callback_metrics:
        val_loss = trainer.callback_metrics.get('val_loss', 0)
        train_loss = trainer.callback_metrics.get('train_loss', 0)
    else:
        val_loss = 0.0
        train_loss = 0.0
    
    total_samples += len(chunk_df)
    losses.append(float(val_loss))
    
    # Collect chunk metrics
    chunk_metrics.append({
        'chunk_idx': chunk_idx + 1,
        'chunk_size': len(chunk_df),
        'val_loss': float(val_loss),
        'train_loss': float(train_loss),
        'total_samples': total_samples
    })
    
    # Log chunk metrics to WandB
    wandb.log({
        'chunk_idx': chunk_idx + 1,
        'chunk_size': len(chunk_df),
        'chunk_val_loss': float(val_loss),
        'chunk_train_loss': float(train_loss),
        'total_samples_processed': total_samples,
        'cumulative_chunks': chunk_idx + 1
    })
    
    print(f"✅ Chunk {chunk_idx+1} completed:")
    print(f"  - Train Loss: {train_loss:.4f}")
    print(f"  - Val Loss: {val_loss:.4f}")
    print(f"  - Total samples processed: {total_samples:,}")
    
print(f"\n🎉 Training completed!")
print(f"📊 Total samples processed: {total_samples:,}")
print(f"📈 Loss progression: {[f'{loss:.4f}' for loss in losses]}")

# Log final summary to WandB
wandb.log({
    'final_total_samples': total_samples,
    'final_avg_val_loss': np.mean(losses),
    'final_best_val_loss': min(losses),
    'total_chunks_processed': len(losses)
})

# Create summary table for WandB
import pandas as pd
chunk_summary_df = pd.DataFrame(chunk_metrics)
wandb.log({"chunk_summary": wandb.Table(dataframe=chunk_summary_df)})

print(f"\n📊 WandB logging completed!")
print(f"🔗 View your run at: {wandb.run.url}")


  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


📁 Dataset path: /Users/takeruito/work/PrfSR/data/training/superfib_r1_dataset.csv
🧠 Model parameters:
  - src_token_num: 80
  - tgt_token_num: 15
  - token_embed_dim: 16
  - max_src_dim: 4
  - max_tgt_dim: 859

🚀 Starting chunked training...

📦 Chunk 1 - Rows: 100,000


Map: 100%|██████████| 100000/100000 [00:13<00:00, 7551.49 examples/s]
Map: 100%|██████████| 100000/100000 [00:02<00:00, 39503.40 examples/s]


{'source': [[[2, 114, 0, 0], [225, 114, 0, 0], [2, 558, 0, 0], [114, 336, 0, 0], [1, 2, 0, 0], [1, 447, 0, 0], [336, 1, 0, 0], [2, 225, 0, 0]], [[2, 225, 0, 0], [336, 1, 0, 0], [669, 1, 0, 0], [1, 225, 0, 0], [336, 114, 0, 0], [2, 447, 0, 0], [1, 15, 0, 0], [336, 225, 0, 0], [126, 114, 0, 0], [1, 669, 0, 0], [2, 2, 0, 0], [114, 114, 0, 0], [1, 114, 0, 0]], [[2, 336, 0, 0], [225, 558, 0, 0], [2, 114, 0, 0], [114, 891, 0, 0], [1, 1, 0, 0], [447, 2, 0, 0], [225, 225, 0, 0], [3, 669, 0, 0], [114, 447, 0, 0], [114, 1, 0, 0], [225, 2, 0, 0], [336, 1, 0, 0], [669, 336, 0, 0], [558, 1, 0, 0]], [[336, 1, 0, 0], [1, 225, 0, 0], [1, 891, 0, 0], [225, 669, 0, 0], [1, 1, 0, 0], [26, 336, 0, 0], [225, 1, 0, 0], [115, 1, 0, 0], [2, 2, 0, 0], [1, 114, 0, 0], [2, 669, 0, 0], [3, 558, 0, 0], [37, 114, 0, 0], [1, 669, 0, 0], [37, 447, 0, 0], [558, 225, 0, 0], [780, 558, 0, 0], [114, 447, 0, 0]], [[669, 1, 0, 0], [2, 114, 0, 0], [3, 2, 0, 0], [558, 336, 0, 0], [1, 447, 0, 0], [2, 2, 0, 0], [15, 37, 0, 0],

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name          | Type               | Params
-----------------------------------------------------
0 | src_embedding | Embedding          | 1.3 K 
1 | fc1           | Linear             | 4.2 K 
2 | fc2           | Linear             | 4.2 K 
3 | tgt_embedding | Embedding          | 960   
4 | pos_enc       | PositionalEncoding | 0     
5 | transformer   | Transformer        | 1.1 M 
6 | fc_out        | Linear             | 975   
7 | loss_fn       | CrossEntropyLoss   | 0     
-----------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.434     Total estimated model params size (MB)


source shape:  torch.Size([26, 7, 4])
target shape:  torch.Size([26, 166])

source shape:  torch.Size([24, 9, 4])
target shape:  torch.Size([24, 162])

source shape:  torch.Size([29, 9, 4])
target shape:  torch.Size([29, 125])

source shape:  torch.Size([25, 8, 4])
target shape:  torch.Size([25, 195])

source shape:  torch.Size([29, 8, 4])
target shape:  torch.Size([29, 177])

source shape:  torch.Size([33, 9, 4])
target shape:  torch.Size([33, 107])

source shape:  torch.Size([27, 10, 4])
target shape:  torch.Size([27, 195])

source shape:  torch.Size([23, 9, 4])
target shape:  torch.Size([23, 170])

source shape:  torch.Size([26, 9, 4])
target shape:  torch.Size([26, 159])

source shape:  torch.Size([26, 9, 4])
target shape:  torch.Size([26, 177])

source shape:  torch.Size([25, 9, 4])
target shape:  torch.Size([25, 160])
🔥 Training on chunk 1...
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   5%|▌         | 201/3785 [01:10<20:53,  2.86it/s, loss=1.04, v_num=hzq7, train_loss=1.050]

Metric val_loss improved. New best score: 1.015


Epoch 0:  11%|█         | 402/3785 [03:03<25:42,  2.19it/s, loss=0.839, v_num=hzq7, train_loss=0.835, val_loss=1.010]

Metric val_loss improved by 0.287 >= min_delta = 0.0. New best score: 0.728


Epoch 0:  14%|█▍        | 547/3785 [04:16<25:16,  2.14it/s, loss=0.663, v_num=hzq7, train_loss=0.672, val_loss=0.728]