# Đề bài: Sử dụng mạng học sâu để phân loại hình ảnh đất đai từ ảnh vệ tinh
## Model: SegNet
## Data: EuroSat-RGB + BigEarthNet


In [None]:
# Start from a clean slate: recursively delete /kaggle/working/ (including any
# earlier clone of the repo) so the `git clone` below does not fail on an
# already-existing target directory.
# NOTE(review): this wipes every file under /kaggle/working/, not only the clone;
# `rm -rf /kaggle/working/Final_exam` would be a narrower alternative if other
# outputs stored there should survive a re-run.
!rm -rf /kaggle/working/


## 1. Clone repo


In [None]:
# Clone the project repository and check out a specific branch.
# Change BRANCH_NAME to the branch you want, e.g. 'segnet', 'dev', 'feature/segnet'.
import subprocess

BRANCH_NAME = "segnet"  # Change the branch name here
REPO_URL = "https://github.com/yuh-tech/ADCV_2025"
# Later cells read requirements.txt, setup_kaggle.py and the project code from here.
CLONE_DIR = "/kaggle/working/Final_exam"

# Option 1: clone the branch directly (works when the branch exists on the remote).
# An argument list (no shell) avoids quoting problems, and the exit code is checked
# so a failed clone is noticed instead of silently ignored.
clone_result = subprocess.run(["git", "clone", "-b", BRANCH_NAME, REPO_URL, CLONE_DIR])

# Option 2 (fallback): clone the default branch, then check out BRANCH_NAME.
# check=True raises CalledProcessError if this also fails, stopping the notebook here
# rather than letting later cells fail with confusing missing-file errors.
if clone_result.returncode != 0:
    print(f"Cloning branch '{BRANCH_NAME}' directly failed; retrying with clone + checkout")
    subprocess.run(["git", "clone", REPO_URL, CLONE_DIR], check=True)
    subprocess.run(["git", "-C", CLONE_DIR, "checkout", BRANCH_NAME], check=True)


## 2. Cài đặt dependencies


In [None]:
# Install the Python packages listed in the cloned repo's requirements.txt.
!pip install -r /kaggle/working/Final_exam/requirements.txt


In [None]:
# Run the repo's Kaggle setup script; what it configures is defined in
# setup_kaggle.py, which is not shown in this notebook.
!python /kaggle/working/Final_exam/setup_kaggle.py


## 3. Setup và Import


In [None]:
# ============================================================================
# SETUP AND IMPORTS
# ============================================================================

import sys
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd

# Auto-detect environment: the presence of /kaggle/input is used as the Kaggle marker.
IS_KAGGLE = os.path.exists('/kaggle/input')

if IS_KAGGLE:
    print("Running on Kaggle environment")
    # On Kaggle, the repo is cloned to /kaggle/working/Final_exam
    project_root = Path('/kaggle/working/Final_exam')
else:
    print("Running on local environment")
    # Locally the notebook lives in notebooks/, so the project root is one level up
    project_root = Path.cwd().parent

# Make the project packages (config, src) importable with highest priority.
# Drop any earlier copies first so re-running this cell does not keep growing
# sys.path with duplicate entries.
project_root_str = str(project_root)
while project_root_str in sys.path:
    sys.path.remove(project_root_str)
sys.path.insert(0, project_root_str)

print(f"Project root: {project_root}")
print("Python path updated")


In [None]:
# Import the project configuration (dataset paths, class mapping) and the
# project modules exercised by the rest of this notebook.
# NOTE: the IS_KAGGLE imported from config rebinds the value detected in the setup cell.
from config import (
    BIGEARTHNET_FOLDERS,
    CLASS_NAMES,
    CORINE_TO_EUROSAT,
    EUROSAT_PATH,
    IS_KAGGLE,
    METADATA_PATH,
    NUM_CLASSES,
    REFERENCE_MAPS_FOLDER,
)
from src.data import (
    BigEarthNetSegmentationDataset,
    EuroSATDataset,
    get_classification_train_augmentation,
    get_segmentation_train_augmentation,
    get_val_augmentation,
)
from src.models.segnet import SegNet, SegNetWithPretrainedEncoder
from src.utils.visualization import COLOR_PALETTE, denormalize_image, mask_to_rgb

print("Imports successful!")
print("\n Paths configuration:")

# Report each configured path and whether it exists on this machine.
print(f"  EuroSAT: {EUROSAT_PATH}")
print(f"  Exists: {EUROSAT_PATH.exists()}")
for path_label, path_value in (
    ("Metadata", METADATA_PATH),
    ("Reference Maps", REFERENCE_MAPS_FOLDER),
):
    print(f"\n  {path_label}: {path_value}")
    print(f"  Exists: {path_value.exists()}")

print(f"\n  BigEarthNet folders: {len(BIGEARTHNET_FOLDERS)} found")
for ben_folder in BIGEARTHNET_FOLDERS:
    print(f"    - {ben_folder} (exists: {ben_folder.exists()})")


## 4. Test SegNet Model Architecture


In [None]:
# Smoke-test both SegNet variants: build each model, report its parameter count,
# run a forward pass on a random batch in inference mode, and verify the output
# holds one logit per class for every input pixel.
print("="*70)
print("TESTING SEGNET MODEL")
print("="*70)

# Random batch of two 3-channel 120x120 images (120 matches the size used for the
# BigEarthNet segmentation transforms in this notebook).
test_input = torch.randn(2, 3, 120, 120)
# Expected logits shape: (batch, NUM_CLASSES, H, W) at the input resolution.
expected_shape = (test_input.shape[0], NUM_CLASSES, *test_input.shape[-2:])

# Test standard SegNet
print("\n1. Standard SegNet (without pretrained encoder):")
model_standard = SegNet(
    in_channels=3,
    num_classes=NUM_CLASSES,
    base_channels=64
)
# eval() so any BatchNorm/Dropout layers run in inference mode; torch.no_grad()
# alone does not stop BatchNorm from updating its running statistics.
model_standard.eval()

# Count parameters
total_params = sum(p.numel() for p in model_standard.parameters())
print(f"  Total parameters: {total_params:,}")

# Test forward pass
with torch.no_grad():
    output = model_standard(test_input)
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {output.shape}")
if tuple(output.shape) != expected_shape:
    raise RuntimeError(
        f"Standard SegNet output shape {tuple(output.shape)} != expected {expected_shape}"
    )
print("  ✓ Standard SegNet works!")

# Test SegNet with pretrained encoder
print("\n2. SegNet with Pretrained Encoder:")
model_pretrained = SegNetWithPretrainedEncoder(
    encoder_name='resnet50',
    num_classes=NUM_CLASSES,
    encoder_pretrained=True,
    freeze_encoder=False
)
model_pretrained.eval()

total_params_pretrained = sum(p.numel() for p in model_pretrained.parameters())
print(f"  Total parameters: {total_params_pretrained:,}")

with torch.no_grad():
    output_pretrained = model_pretrained(test_input)
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {output_pretrained.shape}")
if tuple(output_pretrained.shape) != expected_shape:
    raise RuntimeError(
        f"Pretrained-encoder SegNet output shape {tuple(output_pretrained.shape)} "
        f"!= expected {expected_shape}"
    )
print("  ✓ SegNet with pretrained encoder works!")

print("\n" + "="*70)
print("✅ SegNet model architecture verified!")
print("="*70)


## 5. Test Data Loading


In [None]:
# Load the EuroSAT train/val/test splits and report their sizes, classes and
# class distribution. Failures are printed (with traceback) instead of raised.
import traceback

print("="*70)
print("TESTING EUROSAT DATASET")
print("="*70)

# Light augmentation for training, deterministic transform for val/test.
# 64 is presumably the EuroSAT image size — confirm against the transform helpers.
train_transform = get_classification_train_augmentation(64, strength='light')
val_transform = get_val_augmentation(64)

try:
    train_dataset = EuroSATDataset(EUROSAT_PATH, 'train', train_transform)
    val_dataset = EuroSATDataset(EUROSAT_PATH, 'val', val_transform)
    test_dataset = EuroSATDataset(EUROSAT_PATH, 'test', val_transform)

    for split_label, split_dataset in (
        ("Train", train_dataset),
        ("Val", val_dataset),
        ("Test", test_dataset),
    ):
        print(f"✓ {split_label} samples: {len(split_dataset):,}")
    print(f"✓ Classes: {train_dataset.classes}")

    print("\n Class distribution:")
    # Called for its output; any return value is not used here.
    train_dataset.get_class_distribution()
except Exception as load_error:
    print(f"✗ Error loading EuroSAT: {load_error}")
    traceback.print_exc()


In [None]:
# Load the BigEarthNet metadata table, show how patches are distributed across
# splits, and take small train/validation subsets for the smoke tests below.
print("="*70)
print("LOADING BIGEARTHNET METADATA")
print("="*70)

# Subset sizes for testing; adjust based on your needs.
TEST_SUBSET_SIZE = 100  # rows taken from the 'train' split
VAL_SUBSET_SIZE = 50    # rows taken from the 'validation' split

try:
    metadata_df = pd.read_parquet(METADATA_PATH)
    print(f"✓ Loaded metadata: {len(metadata_df):,} patches")
    print(f"  Columns: {list(metadata_df.columns)}")

    # Show split distribution
    print("\n Split distribution:")
    print(metadata_df['split'].value_counts())

    # .head() keeps the first rows of each split (deterministic, not a random sample).
    train_df = metadata_df[metadata_df['split'] == 'train'].head(TEST_SUBSET_SIZE)
    val_df = metadata_df[metadata_df['split'] == 'validation'].head(VAL_SUBSET_SIZE)

    print(f"\n Using {len(train_df)} train samples (subset for testing)")
    print(f" Using {len(val_df)} val samples (subset for testing)")

    # An empty subset usually means the split labels differ from 'train'/'validation';
    # flag it here rather than letting the dataset/DataLoader cells fail later.
    if train_df.empty or val_df.empty:
        print(
            f"⚠ Empty subset (train={len(train_df)}, val={len(val_df)}); "
            f"available split labels: {list(metadata_df['split'].unique())}"
        )

except Exception as e:
    print(f"✗ Error loading metadata: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Build BigEarthNet segmentation datasets for the train/validation metadata subsets.
# Failures are printed (with traceback) instead of raised.
print("="*70)
print("CREATING BIGEARTHNET DATASET")
print("="*70)

# Light augmentation for training, deterministic transform for validation (size 120).
seg_train_transform = get_segmentation_train_augmentation(120, strength='light')
seg_val_transform = get_val_augmentation(120)

try:
    # Constructor arguments shared by both splits; only the metadata subset and
    # the transform differ between train and validation.
    shared_ben_kwargs = {
        "data_folders": BIGEARTHNET_FOLDERS,
        "reference_maps_folder": REFERENCE_MAPS_FOLDER,
        "corine_to_eurosat_mapping": CORINE_TO_EUROSAT,
        "num_classes": NUM_CLASSES,
        "validate_data": True,
    }

    train_dataset_ben = BigEarthNetSegmentationDataset(
        metadata_df=train_df, transform=seg_train_transform, **shared_ben_kwargs
    )
    val_dataset_ben = BigEarthNetSegmentationDataset(
        metadata_df=val_df, transform=seg_val_transform, **shared_ben_kwargs
    )

    print(f"✓ Created BigEarthNet train dataset: {len(train_dataset_ben)} samples")
    print(f"✓ Created BigEarthNet val dataset: {len(val_dataset_ben)} samples")
except Exception as build_error:
    import traceback

    print(f"✗ Error creating BigEarthNet dataset: {build_error}")
    traceback.print_exc()


In [None]:
from torch.utils.data import DataLoader

# Wrap the BigEarthNet train subset in a DataLoader, pull one batch, run it through
# SegNet, and verify the prediction logits line up with the ground-truth masks.
print("="*70)
print("TESTING DATALOADER WITH SEGNET")
print("="*70)

# Create DataLoader
train_loader = DataLoader(
    train_dataset_ben,
    batch_size=4,
    shuffle=True,
    num_workers=0,  # Use 0 for debugging: loading runs in the main process
    pin_memory=False
)

# Get one batch
try:
    batch = next(iter(train_loader))

    print("✓ Batch loaded successfully!")
    print(f"  Images: {batch['image'].shape}")
    print(f"  Masks: {batch['mask'].shape}")
    print(f"  Patch IDs: {batch['patch_id']}")

    # Test SegNet forward pass with real data
    print("\n Testing SegNet forward pass...")
    model_test = SegNetWithPretrainedEncoder(
        encoder_name='resnet50',
        num_classes=NUM_CLASSES,
        encoder_pretrained=True
    )
    model_test.eval()

    images = batch['image']
    masks = batch['mask']
    with torch.no_grad():
        predictions = model_test(images)

    print(f"  Input images: {images.shape}")
    print(f"  Predictions: {predictions.shape}")
    print(f"  Masks: {masks.shape}")

    # Predictions must be (batch, NUM_CLASSES, H, W) with the mask's batch size and
    # spatial size, so the per-pixel argmax can be compared against the ground truth.
    expected_shape = (masks.shape[0], NUM_CLASSES, *masks.shape[-2:])
    if tuple(predictions.shape) != expected_shape:
        raise ValueError(
            f"Prediction shape {tuple(predictions.shape)} does not match expected "
            f"{expected_shape} (mask shape {tuple(masks.shape)})"
        )
    print("  ✓ SegNet forward pass successful!")

except Exception as e:
    print(f"✗ Error: {e}")
    import traceback
    traceback.print_exc()


## 6. Visualize SegNet Predictions


In [None]:
# Visualize SegNet predictions: for up to three images of one batch, show the input
# image, the ground-truth mask, and the argmax prediction of a freshly built SegNet
# (no training has been run, so predictions are expected to be poor).
print("="*70)
print("VISUALIZING SEGNET PREDICTIONS")
print("="*70)

import matplotlib.pyplot as plt
from src.utils.visualization import visualize_segmentation

# Get a sample batch
sample_batch = next(iter(train_loader))
images = sample_batch['image']
masks = sample_batch['mask']

# Create model and predict
model_vis = SegNetWithPretrainedEncoder(
    encoder_name='resnet50',
    num_classes=NUM_CLASSES,
    encoder_pretrained=True
)
model_vis.eval()

with torch.no_grad():
    predictions = model_vis(images)

# Class index per pixel for the whole batch: argmax over the class dimension (dim 1).
pred_masks = torch.argmax(predictions, dim=1).cpu().numpy()

# Only draw as many rows as there are images (at most 3); squeeze=False keeps
# `axes` 2-D even when a single row is plotted.
MAX_ROWS = 3
n_rows = min(MAX_ROWS, len(images))
fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
fig.suptitle('SegNet Predictions (Untrained Model)', fontsize=16, fontweight='bold')

for i in range(n_rows):
    panels = (
        (denormalize_image(images[i]), 'Original Image'),
        (mask_to_rgb(masks[i].cpu().numpy(), COLOR_PALETTE), 'Ground Truth'),
        (mask_to_rgb(pred_masks[i], COLOR_PALETTE), 'SegNet Prediction'),
    )
    for ax, (panel_image, panel_title) in zip(axes[i], panels):
        ax.imshow(panel_image)
        ax.set_title(panel_title, fontweight='bold')
        ax.axis('off')

plt.tight_layout()
plt.show()

print("✓ Visualization complete!")


## 7. Training SegNet

Bạn có thể chạy training bằng cách:
- **Local** (chạy từ thư mục gốc của repo): `python train_stage2_segnet.py --epochs 50`
- **Kaggle**: `!cd /kaggle/working/Final_exam && python train_stage2_segnet.py --epochs 50 --batch-size 16`
