# In[ ]:
import torch
import torch.nn as nn
from torchvision import models
import os
from Model.salient_bezier_cutmix import salient_and_rect_dual_save_pipeline

# ---- Runtime configuration -------------------------------------------------
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Number of augmented samples requested from each pipeline pass.
n = 125

# Input folders, one per class.
Target_Drive = "./Data/Forest"       # forest-class images
Source_Drive = "./Data/Non_Forest"   # non-forest-class images

# Output folders for the generated images and their labels.
out_img_drive = "./Data/Augmented_Image"
out_label_drive = "./Data/Augmented_Label"

# Create both output folders up front; folders that already exist are fine.
for _output_dir in (out_img_drive, out_label_drive):
    os.makedirs(_output_dir, exist_ok=True)

# ---- Classifier used to guide the augmentation -----------------------------
# Rebuild the EfficientNet-B4 architecture that the fine-tuned checkpoint
# below was trained from. ImageNet weights are deliberately NOT requested:
# load_state_dict() is called with its default strict=True, so every
# parameter and buffer is overwritten by the checkpoint anyway. Skipping the
# download gives the same final model and lets the script run offline.
model = models.efficientnet_b4(weights=None)

# The stock stem convolution expects 3 input channels (RGB). Replace it with
# an otherwise identical convolution that accepts 4 bands (e.g. RGB + NIR).
# Its initial weights are irrelevant because they come from the checkpoint.
orig_conv = model.features[0][0]
new_conv = nn.Conv2d(
    in_channels=4,
    out_channels=orig_conv.out_channels,
    kernel_size=orig_conv.kernel_size,
    stride=orig_conv.stride,
    padding=orig_conv.padding,
    dilation=orig_conv.dilation,
    groups=orig_conv.groups,
    bias=orig_conv.bias is not None,
)
model.features[0][0] = new_conv

# Single-logit head for the binary forest / non-forest task.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)

# Load the fine-tuned weights. The model is still on the CPU at this point,
# so map the tensors there too. Without map_location, a checkpoint saved from
# CUDA tensors cannot be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load("efficientb4_2k_forest_nonforest.pth", map_location="cpu")
)

# The model is only used for inference from here on. Disable dropout and
# freeze BatchNorm running statistics so its outputs are deterministic and
# the loaded weights are not mutated.
# NOTE(review): device placement is left to the pipeline (it receives
# `device`), as in the original script — confirm it moves the model.
model.eval()

# ---- Salient CutMix augmentation -------------------------------------------
# Two passes with the class roles swapped. The first uses the forest images
# as targets and non-forest images as sources; the second does the reverse.
# Each tuple is (target folder, source folder, output prefix, inverse flag).
# NOTE(review): the second pass passes inverse=None rather than False, exactly
# as before. Confirm how the pipeline treats None vs. False before changing it.
_augmentation_passes = (
    (Target_Drive, Source_Drive, 'forest_as_target', True),
    (Source_Drive, Target_Drive, 'non_forest_as_target', None),
)

for _target_dir, _source_dir, _prefix, _inverse in _augmentation_passes:
    salient_and_rect_dual_save_pipeline(
        model=model,
        target_drive=_target_dir,
        source_drive=_source_dir,
        out_img_drive=out_img_drive,
        out_label_drive=out_label_drive,
        n=n,
        device=DEVICE,
        prefix=_prefix,
        inverse=_inverse,
    )
