In [2]:
import torch
import logging
from pathlib import Path
import yaml
from datetime import datetime
from typing import Dict

from CNN import create_model
from training import Trainer

from data_analysis import analyze_dataset, print_preprocessing_recommendations
from data_preprocessing import DatasetConfig, DataModule, ModelType, DatasetType

from vis_analysis_pipeline import VisualizationManager

In [2]:

def load_config(config_path: str) -> Dict:
    """Load the pipeline configuration from a YAML file and coerce numeric fields.

    The optimiser/scheduler hyper-parameters in the ``training`` section are
    converted with ``float()`` because values written in exponent form may be
    returned by the YAML loader as strings. Keys that are absent are filled
    with the defaults listed below.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration. It always contains a ``training`` mapping
        holding every hyper-parameter below as a ``float``; all other content
        of the file is returned untouched.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If one of the hyper-parameters cannot be parsed as a float.
    """
    # Do not depend on the platform default encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # An empty document is parsed as None; treat it as an empty configuration.
    if config is None:
        config = {}

    # Attach the section to `config` when it is missing (or present but empty),
    # otherwise the defaults filled in below would be written to a throw-away
    # dict and never reach the caller.
    training = config.get('training')
    if training is None:
        training = {}
        config['training'] = training

    # Hyper-parameters that must be floats, with their fallback values.
    float_defaults = {
        'learning_rate': 1e-4,
        'weight_decay': 1e-5,
        'min_lr': 1e-6,
        'max_lr': 1e-3,
        'warmup_start_lr': 1e-7,
        'plateau_factor': 0.5,
    }
    for key, default in float_defaults.items():
        training[key] = float(training.get(key, default))

    return config


In [3]:
def setup_logging(log_dir: Path):
    """Send root-logger output to ``<log_dir>/pipeline.log``.

    The directory is created if necessary. Handlers already attached to the
    root logger are removed first: ``logging.basicConfig`` silently does
    nothing when the root logger already has handlers, so without this step a
    second call (e.g. re-running the setup cell, which produces a new
    timestamped ``log_dir``) would keep writing to the previous log file.

    Args:
        log_dir: Directory that will contain ``pipeline.log``.
    """
    log_dir.mkdir(parents=True, exist_ok=True)

    # Detach and close stale handlers so basicConfig below takes effect.
    # Iterate over a copy because removeHandler mutates the list.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        filename=log_dir / 'pipeline.log',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

In [4]:
# Load configuration
# The relative path is resolved against the notebook's working directory.
# `config` is the dict that the following cells read their paths, model and
# training settings from.
config_path = 'config.yaml'
config = load_config(config_path)

In [5]:
# Resolve the dataset root from the configuration.
base_path = Path(config['data_path'])

# A timestamp suffix gives every run its own experiment name and log folder.
experiment_name = "cnn_constellation_classification_" + datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = Path(config['log_dir'], experiment_name)

# Route log records to <log_dir>/pipeline.log, then mark the start of the run.
setup_logging(log_dir)
logging.info("Starting pipeline execution")

In [6]:
# 1. Data Analysis
logging.info("Running data analysis...")

# Analyse the training split: the label CSV and the image folder both live
# under <base_path>/train. The returned results object is only used to print
# the preprocessing recommendations below.
analysis_results = analyze_dataset(
    csv_file=base_path / "train/_classes.csv",
    img_dir=base_path / "train/images"
)
print_preprocessing_recommendations(analysis_results)


=== Dataset Statistics ===
Total number of images: 1641
Total number of labels: 4909
Average labels per image: 2.99

=== Class Imbalance Analysis ===

Class imbalance ratios (relative to most frequent class):
 cassiopeia: 1.00:1
 pleiades: 1.16:1
 ursa_major: 1.21:1
 cygnus: 1.27:1
 lyra: 1.28:1
 moon: 1.36:1
 orion: 1.46:1
 bootes: 1.96:1
 taurus: 1.97:1
 aquila: 1.98:1
 gemini: 2.05:1
 canis_minor: 2.17:1
 leo: 2.31:1
 scorpius: 2.55:1
 canis_major: 2.56:1
 sagittarius: 2.81:1

=== Image Properties Analysis ===

Image Dimensions Summary:
Width  - Mean: 640.0, Min: 640, Max: 640
Height - Mean: 640.0, Min: 640, Max: 640
Aspect Ratio - Mean: 1.00, Min: 1.00, Max: 1.00

=== Preprocessing Recommendations ===

2. Multi-label Specific:
- Use BCEWithLogitsLoss for training
- Consider label correlation in loss function
- Implement proper multi-label evaluation metrics

3. Recommended Augmentation Techniques:
- Random horizontal flips (already implemented)
- Random rotation (±10 degrees)
- Co

In [7]:
# 2. Data Preprocessing
logging.info("Setting up data module...")

# Wrap the dataset root in a DatasetConfig, then build the DataModule that the
# trainer and the visualization cell consume. Batch size and worker count are
# taken from the `training` section of the config.
data_config = DatasetConfig(base_path)
data_module = DataModule(
    data_config=data_config,
    model_type=ModelType.CNN,
    batch_size=config['training']['batch_size'],
    num_workers=config['training']['num_workers']
)

  return torch.FloatTensor(weights)


In [8]:
# 3. Model Creation
logging.info("Creating model...")

# Prefer the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Backbone and dropout come from the `model` section of the config.
# NOTE(review): num_classes=16 is hard-coded; it matches the 16 classes listed
# by the data-analysis output — confirm it stays in sync with the dataset's
# class columns if the data changes.
model = create_model(
    model_type='cnn',
    num_classes=16,
    pretrained=True,
    backbone=config['model']['backbone'],
    dropout_rate=config['model']['dropout_rate'],
).to(device)



In [9]:
# 4. Training
logging.info("Starting training...")

# The trainer gets the whole `training` config section, the model already
# placed on `device`, and the experiment name created in the setup cell.
# Weights & Biases reporting is switched by `logging.use_wandb` in the config.
# NOTE(review): in the recorded run the best checkpoint was written to
# logs\<experiment_name>\best_model.pt — that location is decided inside
# training.Trainer, not in this notebook.
trainer = Trainer(
    model=model,
    data_module=data_module,
    config=config['training'],
    device=device,
    experiment_name=experiment_name,
    use_wandb=config['logging']['use_wandb']
)
# Run the configured number of epochs.
trainer.train(num_epochs=config['training']['num_epochs'])

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marsive[0m. Use [1m`wandb login --relogin`[0m to force relogin



Starting training...
Log directory: logs\cnn_constellation_classification_20241120_145519
Number of epochs: 10

Epoch 1/10


Epoch 0: 100%|██████████| 51/51 [00:29<00:00,  1.74it/s, loss=1.93]



Training metrics:
Loss: 1.8207
Mean AP: 0.4146


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.28s/it]



Validation metrics:
Loss: 0.8932
Mean AP: 0.3577
Previous best MAP: 0.0000
Learning rate: 0.000333

Improvement detected! Previous best: 0.0000, Current: 0.3577

Attempting to save checkpoint:
Current MAP: 0.35767286959241723
Best MAP so far: 0.35767286959241723
Saving checkpoint to: logs\cnn_constellation_classification_20241120_145519\best_model.pt
Successfully saved checkpoint!

Epoch 2/10


Epoch 1: 100%|██████████| 51/51 [00:28<00:00,  1.77it/s, loss=0.48] 



Training metrics:
Loss: 0.8490
Mean AP: 0.6499


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.28s/it]



Validation metrics:
Loss: 0.3408
Mean AP: 0.7857
Previous best MAP: 0.3577
Learning rate: 0.000667

Improvement detected! Previous best: 0.3577, Current: 0.7857

Attempting to save checkpoint:
Current MAP: 0.7856525490591131
Best MAP so far: 0.7856525490591131
Saving checkpoint to: logs\cnn_constellation_classification_20241120_145519\best_model.pt
Successfully saved checkpoint!

Epoch 3/10


Epoch 2: 100%|██████████| 51/51 [00:28<00:00,  1.78it/s, loss=0.217]



Training metrics:
Loss: 0.3055
Mean AP: 0.7700


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.27s/it]



Validation metrics:
Loss: 0.2948
Mean AP: 0.6282
Previous best MAP: 0.7857
Learning rate: 0.001000

No improvement. Previous best: 0.7857, Current: 0.6282

Epoch 4/10


Epoch 3: 100%|██████████| 51/51 [00:28<00:00,  1.78it/s, loss=0.16]  



Training metrics:
Loss: 0.1996
Mean AP: 0.8124


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.28s/it]



Validation metrics:
Loss: 0.1475
Mean AP: 0.7124
Previous best MAP: 0.7857
Learning rate: 0.000501

No improvement. Previous best: 0.7857, Current: 0.7124

Epoch 5/10


Epoch 4: 100%|██████████| 51/51 [00:29<00:00,  1.74it/s, loss=0.126] 



Training metrics:
Loss: 0.1199
Mean AP: 0.8503


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.30s/it]



Validation metrics:
Loss: 0.0997
Mean AP: 0.7533
Previous best MAP: 0.7857
Learning rate: 0.001000

No improvement. Previous best: 0.7857, Current: 0.7533

Epoch 6/10


Epoch 5: 100%|██████████| 51/51 [00:28<00:00,  1.77it/s, loss=0.183] 



Training metrics:
Loss: 0.1084
Mean AP: 0.8865


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.31s/it]



Validation metrics:
Loss: 0.1167
Mean AP: 0.7775
Previous best MAP: 0.7857
Learning rate: 0.000501

No improvement. Previous best: 0.7857, Current: 0.7775

Epoch 7/10


Epoch 6: 100%|██████████| 51/51 [00:29<00:00,  1.75it/s, loss=0.0778]



Training metrics:
Loss: 0.0813
Mean AP: 0.9037


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.27s/it]



Validation metrics:
Loss: 0.0648
Mean AP: 0.8275
Previous best MAP: 0.7857
Learning rate: 0.001000

Improvement detected! Previous best: 0.7857, Current: 0.8275

Attempting to save checkpoint:
Current MAP: 0.8275072900032922
Best MAP so far: 0.8275072900032922
Saving checkpoint to: logs\cnn_constellation_classification_20241120_145519\best_model.pt
Successfully saved checkpoint!

Epoch 8/10


Epoch 7: 100%|██████████| 51/51 [00:28<00:00,  1.77it/s, loss=0.109] 



Training metrics:
Loss: 0.1019
Mean AP: 0.9010


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.30s/it]



Validation metrics:
Loss: 0.2500
Mean AP: 0.5236
Previous best MAP: 0.8275
Learning rate: 0.000501

No improvement. Previous best: 0.8275, Current: 0.5236

Epoch 9/10


Epoch 8: 100%|██████████| 51/51 [00:28<00:00,  1.77it/s, loss=0.0796]



Training metrics:
Loss: 0.0812
Mean AP: 0.9125


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.29s/it]



Validation metrics:
Loss: 0.0547
Mean AP: 0.8520
Previous best MAP: 0.8275
Learning rate: 0.001000

Improvement detected! Previous best: 0.8275, Current: 0.8520

Attempting to save checkpoint:
Current MAP: 0.8519546669866499
Best MAP so far: 0.8519546669866499
Saving checkpoint to: logs\cnn_constellation_classification_20241120_145519\best_model.pt
Successfully saved checkpoint!

Epoch 10/10


Epoch 9: 100%|██████████| 51/51 [00:28<00:00,  1.77it/s, loss=0.104] 



Training metrics:
Loss: 0.0706
Mean AP: 0.9291


Validation: 100%|██████████| 15/15 [00:19<00:00,  1.30s/it]


Validation metrics:
Loss: 0.0944
Mean AP: 0.8298
Previous best MAP: 0.8520
Learning rate: 0.000501

No improvement. Previous best: 0.8520, Current: 0.8298





In [10]:

# 5. Visualization and Analysis
logging.info("Running visualization pipeline...")

viz_manager = VisualizationManager(
    save_dir=log_dir / 'visualizations',
    class_names=data_module.datasets[DatasetType.TRAIN].class_columns
)

# Only this many validation samples are drawn in the prediction grid, so only
# this many input images are kept in memory (predictions and targets are small
# and are still collected for the whole validation set).
NUM_VIZ_SAMPLES = 16

# NOTE(review): `model` carries whatever weights the trainer left in it. In the
# recorded run the last epoch was not the best one, so these figures may not
# correspond to best_model.pt — confirm whether Trainer restores the best
# checkpoint before relying on them.

# Get validation predictions for visualization
val_loader = data_module.get_dataloader(DatasetType.VALID)
all_preds = []
all_targets = []
all_images = []
num_images_kept = 0

model.eval()
with torch.no_grad():
    for images, targets in val_loader:
        # Keep input images only until the grid is full, taken before the
        # transfer to `device` so they never need to be copied back.
        remaining = NUM_VIZ_SAMPLES - num_images_kept
        if remaining > 0:
            kept = images[:remaining].cpu()
            all_images.append(kept)
            num_images_kept += kept.shape[0]

        outputs = model(images.to(device))
        # Independent per-class probabilities (multi-label setup).
        predictions = torch.sigmoid(outputs)
        all_preds.append(predictions.cpu())
        all_targets.append(targets)

all_preds = torch.cat(all_preds)
all_targets = torch.cat(all_targets)
all_images = torch.cat(all_images)

# Create visualizations
viz_manager.visualize_predictions(
    all_images[:NUM_VIZ_SAMPLES],
    all_preds[:NUM_VIZ_SAMPLES],
    all_targets[:NUM_VIZ_SAMPLES],
)

In [11]:
# ROC curves from the validation predictions and targets collected in the
# previous cell; both tensors are converted to NumPy arrays for the call.
viz_manager.plot_roc_curves(all_preds.numpy(), all_targets.numpy())

In [12]:
# Precision-recall curves from the same validation predictions and targets,
# converted to NumPy arrays for the call.
viz_manager.plot_precision_recall_curves(all_preds.numpy(), all_targets.numpy())
# Final log entry of the training/visualization part of the notebook.
logging.info("Pipeline execution completed successfully")

In [None]:
from cnn_inference import CNNPredictor
import cv2

# Initialize predictor
# NOTE(review): model_path points at one specific, timestamped training run,
# not at the `log_dir` of the current session — update it when a different
# checkpoint should be evaluated.
predictor = CNNPredictor(
    model_path="logs/cnn_constellation_classification_20241120_145519/best_model.pt",
    config_path="config.yaml"
)

# Test single image prediction
# Forward slashes work on Windows as well as on POSIX systems (and match the
# model_path above); the previous backslash-separated literal only resolved on
# Windows.
image_path = "data/constellation_dataset_1/test/images/2022-01-05-00-00-00-s_png_jpg.rf.098ebe8a5c09f983736111049dfefc1d.jpg"
predictions, visualization = predictor.predict(
    image_path,
    conf_thresh=0.5,
    return_visualization=True
)

# Save results
output_dir = Path("inference_results/cnn")
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "prediction.jpg"

# cv2.imwrite reports failure through its boolean return value instead of
# raising, so check it rather than letting a failed write pass unnoticed.
if not cv2.imwrite(str(output_path), visualization):
    raise IOError(f"Could not write visualization to {output_path}")

  checkpoint = torch.load(model_path, map_location=self.device)


True