In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import json
from tabulate import tabulate

from models.Regression_models import (get_thz_model, 
    get_loss_function, 
    advanced_train_model,
    evaluate_model, AdaptivePrecisionLoss)
from models.utils import identify_device, display_model
from models.regression_utils import get_train_val_loaders, denormalize_material_params
from torch.utils.data import DataLoader

In [None]:
# Seed the global RNGs first so that weight initialisation (and any
# torch/numpy-RNG-driven shuffling or splitting inside the data helpers —
# not verified here) is repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included
np.random.seed(SEED)

# Select the compute device once; later cells pass `device` to the loss,
# the training/evaluation helpers and the prediction cell.
device = identify_device()

In [None]:
# Load the pre-generated 3-layer dataset and build train/validation loaders.
# val_split=0.1 presumably holds out 10% for validation (split logic lives in
# get_train_val_loaders). num_samples is returned by the helper but not used
# by the cells visible here.
file_path = "regression_data/train_3_layer_512_nonoise_n1to8.pt"

split_result = get_train_val_loaders(
    val_split=0.1,
    batch_size=128,
    dataset_path=file_path,
)
train_loader, val_loader, num_samples = split_result

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

In [None]:
# Choose the architecture and loss presets used throughout the notebook, and
# build one instance of the model to report its parameter count.
model_type = 'ultra'      # "much larger" model variant, per the original note
loss_type = 'weighted'
model = get_thz_model(model_type)

total_params = sum(weight.numel() for weight in model.parameters())
# NOTE(review): originally expected to be roughly 10M — confirm against the architecture.
print(f"Ultra model parameters: {total_params:,}")

In [None]:
print("=== Testing Overfitting Capability ===")

# Sanity check: a capable architecture should be able to drive the training
# loss on a tiny subset close to zero.
OVERFIT_SAMPLES = 32           # size of the tiny training subset
OVERFIT_LOSS_THRESHOLD = 1e-3  # final train loss below this counts as "can overfit"

# Collect the first OVERFIT_SAMPLES (input, target) pairs from the training
# loader. Iterating a batch tensor yields one sample along dim 0, so each
# sample keeps its per-item shape (presumably [channels, length]).
small_dataset = []
for batch_inputs, batch_targets in train_loader:
    small_dataset.extend(zip(batch_inputs, batch_targets))
    if len(small_dataset) >= OVERFIT_SAMPLES:
        break
small_dataset = small_dataset[:OVERFIT_SAMPLES]

small_loader = DataLoader(small_dataset, batch_size=8, shuffle=True)

# Fresh, untrained model plus the adaptive loss used for this experiment.
test_model = get_thz_model(model_type)
test_loss_fn = AdaptivePrecisionLoss()

# Train the fresh test_model (not the `model` built in an earlier cell).
# The higher LR is meant to overfit quickly. patience > num_epochs, so early
# stopping presumably never triggers here.
test_model, test_history = advanced_train_model(
    model=test_model,
    train_loader=small_loader,
    num_epochs=200,
    initial_lr=5e-3,
    loss_fn=test_loss_fn,
    device=device,
    patience=300,
    save_best=False,
)

final_train_loss = test_history['train_loss'][-1]
print(f"Final training loss on {len(small_dataset)} samples: {final_train_loss:.6f}")
if final_train_loss < OVERFIT_LOSS_THRESHOLD:
    print("✅ Model can overfit! Architecture is capable.")
else:
    print("⚠️  Model struggling to overfit. May need architecture changes.")

In [None]:
# Evaluate the overfit model on the tiny training subset (small_loader) it was
# trained on. Results use overfit_* names so that re-running this cell can never
# clobber the validation-set `metrics` / `pred_denorm` / `target_denorm`
# that later cells rely on.
overfit_metrics, overfit_pred_denorm, overfit_target_denorm = evaluate_model(
    test_model, small_loader, denormalize_material_params, device=device
)

# Display per-parameter RMSE on the training subset as the cell output.
{param: metric['rmse'] for param, metric in overfit_metrics.items()}

In [None]:
print("\n=== Starting Full Training ===")

# Start from a freshly initialised network (1 input channel, 9 outputs) rather
# than any model trained in earlier cells.
model = get_thz_model(model_type, input_channels=1, output_dim=9)
loss_fn = get_loss_function(loss_type).to(device)

trained_model, history = advanced_train_model(
    model=model,
    loss_fn=loss_fn,
    device=device,
    # data
    train_loader=train_loader,
    val_loader=val_loader,
    # optimisation schedule
    num_epochs=500,
    initial_lr=1e-3,
    # early stopping: give up after 50 epochs without an improvement of at
    # least 1e-6 (presumably measured on validation loss)
    patience=50,
    min_delta=1e-6,
    # checkpoint the best model and report metrics in real-world units
    save_best=True,
    model_save_path="best_thz_advanced_model.pt",
    denormalize_fn=denormalize_material_params,
)

print("Training completed!")

In [None]:
def _finish_epoch_axis(ax, ylabel, title):
    """Apply the shared per-epoch styling: axis labels, title, log y-scale and a light grid."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)


# Two panels: loss curves (left) and the learning-rate schedule (right).
fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(15, 5))

for history_key, curve_label in (('train_loss', 'Training Loss'),
                                 ('val_loss', 'Validation Loss')):
    loss_ax.plot(history[history_key], label=curve_label, alpha=0.8)
_finish_epoch_axis(loss_ax, 'Loss', 'Training and Validation Loss')
loss_ax.legend()

lr_ax.plot(history['learning_rate'], color='orange')
_finish_epoch_axis(lr_ax, 'Learning Rate', 'Learning Rate Schedule')

plt.tight_layout()
plt.show()

print(f"Best validation loss: {min(history['val_loss']):.6f}")

In [None]:
print("\n=== Comprehensive Evaluation ===")

# Evaluate the fully trained model on the held-out validation split
# (denormalised, i.e. real-world values).
metrics, pred_denorm, target_denorm = evaluate_model(
    trained_model, val_loader, denormalize_material_params, device
)

# Required RMSE per parameter.
# NOTE(review): units follow denormalize_material_params — confirm (the d targets
# of 1e-6 suggest metres).
target_accuracy = {
    'n1': 0.01, 'n2': 0.01, 'n3': 0.01,
    'k1': 1e-5, 'k2': 1e-5, 'k3': 1e-5,
    'd1': 1e-6, 'd2': 1e-6, 'd3': 1e-6
}

# Fail with a clear message if evaluate_model reports a parameter that has no target.
missing_targets = sorted(set(metrics) - set(target_accuracy))
assert not missing_targets, f"No target accuracy defined for: {missing_targets}"

table_data = []
total_meeting_target = 0
for param, metric in metrics.items():
    target_acc = target_accuracy[param]
    rmse_ratio = metric['rmse'] / target_acc  # how many times over/under target
    meets = metric['rmse'] <= target_acc
    total_meeting_target += int(meets)
    meets_target = "✅" if meets else "❌"

    table_data.append([
        param,
        f"{metric['rmse']:.2e}",
        f"{metric['mae']:.2e}",
        f"{metric['r2']:.4f}",
        f"{metric['max_error']:.2e}",
        f"{target_acc:.0e}",
        f"{rmse_ratio:.1f}x",
        meets_target
    ])

headers = ['Param', 'RMSE', 'MAE', 'R²', 'Max Error', 'Target', 'Ratio', 'Meets Target']
print(tabulate(table_data, headers=headers, tablefmt='grid'))

# Denominator is the number of evaluated parameters, not a hard-coded 9.
print(f"\nParameters meeting target accuracy: {total_meeting_target}/{len(metrics)}")

In [None]:
# Predict on the first validation batch and bring predictions and targets back
# to real-world units for plotting.
trained_model.eval()
with torch.no_grad():
    inputs, targets = next(iter(val_loader))
    inputs = inputs.to(device)
    targets = targets.to(device)
    predictions = trained_model(inputs)

    pred_denorm_batch = denormalize_material_params(predictions.cpu())
    target_denorm_batch = denormalize_material_params(targets.cpu())

N_SHOW = 10  # samples plotted per panel

# NOTE(review): assumes denormalize_material_params returns columns in this
# interleaved per-layer (n, k, d) order — confirm; a mismatch mislabels panels.
param_names = ['n1', 'k1', 'd1', 'n2', 'k2', 'd2', 'n3', 'k3', 'd3']

# One predicted-vs-true panel per parameter. The R² in each title comes from the
# full validation-set `metrics`, while only N_SHOW points from one batch are drawn.
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
for ax, (col, param) in zip(axes.ravel(), enumerate(param_names)):
    true_vals = target_denorm_batch[:N_SHOW, col].numpy()
    predicted_vals = pred_denorm_batch[:N_SHOW, col].numpy()

    ax.scatter(true_vals, predicted_vals, alpha=0.7)

    # y = x reference line spanning both sets of values
    lo = min(true_vals.min(), predicted_vals.min())
    hi = max(true_vals.max(), predicted_vals.max())
    ax.plot([lo, hi], [lo, hi], 'r--', alpha=0.8, label='Perfect')

    ax.set_xlabel(f'True {param}')
    ax.set_ylabel(f'Predicted {param}')
    ax.set_title(f'{param} (R² = {metrics[param]["r2"]:.3f})')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Persist the training curves and the final validation metrics.
torch.save(history, 'training_history.pt')

# Cast every metric value to a built-in float so json can serialise it
# (values may be numpy/torch scalars).
metrics_serializable = {
    param: {name: float(value) for name, value in metric.items()}
    for param, metric in metrics.items()
}

with open('final_metrics.json', 'w') as f:
    json.dump(metrics_serializable, f, indent=2)

print("Results saved!")
# NOTE(review): this path must match model_save_path in the full-training cell.
print("Model saved as: best_thz_advanced_model.pt")
print("Training history saved as: training_history.pt")
print("Metrics saved as: final_metrics.json")

# Count the parameters of the model that was actually trained, rather than
# reusing total_params from the earlier size-check model.
trained_params = sum(p.numel() for p in trained_model.parameters())

print("\n=== Final Model Summary ===")
print(f"Model type: {model_type}")
print(f"Total parameters: {trained_params:,}")
print(f"Loss function: {loss_type}")
print(f"Best validation loss: {min(history['val_loss']):.6f}")
print(f"Parameters meeting target: {total_meeting_target}/{len(metrics)}")

In [None]:
# Printed checklist of configurations to try if the targets above are not met.
# Nothing here is executed; it only prints suggestions.

print("\n=== Alternative Configurations to Try ===")
print("1. MultiHead model:")
print("   model = get_thz_model('multihead')")
print("   loss_fn = get_loss_function('multitask')")

print("\n2. Higher learning rate:")
print("   initial_lr=5e-3")

print("\n3. Longer training:")
print("   num_epochs=1000, patience=100")

print("\n4. Different loss weighting:")
# The models are imported from `models.Regression_models`, so point at that file
# (the name is case-sensitive on most filesystems).
print("   Edit WeightedParameterLoss weights in models/Regression_models.py")

print("\nIf overfitting test failed, the architecture needs to be even larger!")