In [None]:
import sys
import os
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from lime import lime_image
from skimage.segmentation import mark_boundaries
import yaml
import pytorch_lightning as pl
from densenet201 import DenseNetModel
%matplotlib inline

# 랜덤 시드 고정 함수
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA
    devices when CUDA is available), then puts cuDNN into deterministic,
    non-benchmarking mode so convolution algorithm choice is reproducible.

    Note: ``PYTHONHASHSEED`` is only read when an interpreter starts, so
    exporting it here affects child processes, not ``hash()`` in this process.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)  # one shared seed for the whole notebook

# 모델, 디바이스, 이미지 변환 함수
def prepare_model_and_transform(model, device='cuda'):
    """Move ``model`` to ``device`` in eval mode and build its input transform.

    Args:
        model: module to prepare; modified in place (``.to`` / ``.eval``).
        device: target device string, ``'cuda'`` by default.

    Returns:
        ``(model, transform)``. ``transform`` resizes a PIL image to 256x256,
        converts it to a tensor and normalizes each channel with the ImageNet
        mean/std values used by torchvision's pretrained models.
    """
    model.to(device)
    model.eval()

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return model, transforms.Compose(pipeline)

# 이미지 시각화 함수
def display_image(image, title=''):
    """Draw ``image`` with the axes hidden and show the figure immediately.

    Args:
        image: anything ``plt.imshow`` accepts (e.g. an HxWx3 array).
        title: optional title; a falsy value (the default ``''``) sets none.
    """
    plt.imshow(image)
    plt.axis('off')
    wants_title = bool(title)
    if wants_title:
        plt.title(title)
    plt.show()

# Grad-CAM 적용 함수
# def apply_grad_cam(model, image_path, model_name):
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     model, transform = prepare_model_and_transform(model, device)

#     img = Image.open(image_path).convert('RGB')
#     img_resized = img.resize((256, 256))
#     img_tensor = transform(img_resized).unsqueeze(0).to(device)
    
#     target_layer = get_target_layer(model, model_name)
#     cam = GradCAM(model=model, target_layers=[target_layer])
#     grayscale_cam = cam(input_tensor=img_tensor)[0, :]
    
#     img_np = np.array(img_resized) / 255.0
#     visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
    
#     display_image(visualization, 'Grad-CAM')

#     return grayscale_cam  # 반환된 CAM을 insertion, deletion에서 사용

# Grad-CAM 마스크 시각화 확인
def apply_grad_cam(model, image_path, model_name):
    """Compute a Grad-CAM map for one image, plot it twice and return it.

    The first figure overlays the map on the resized image. The second shows
    the raw map with a colour bar, so the mask values can be inspected.

    Args:
        model: classifier to explain; moved to the selected device, eval mode.
        image_path: image to explain; converted to RGB and resized to 256x256.
        model_name: passed to ``get_target_layer`` to choose the hooked layer.
            NOTE(review): ``get_target_layer`` is not defined in the code shown
            here -- confirm another cell/module provides it.

    Returns:
        The CAM of the first (only) image in the batch (``[0, :]``). The
        insertion/deletion evaluation reuses it as its saliency mask.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, transform = prepare_model_and_transform(model, device)

    resized = Image.open(image_path).convert('RGB').resize((256, 256))
    input_batch = transform(resized).unsqueeze(0).to(device)

    layer = get_target_layer(model, model_name)
    grad_cam = GradCAM(model=model, target_layers=[layer])
    cam_map = grad_cam(input_tensor=input_batch)[0, :]

    # show_cam_on_image blends the heatmap onto an RGB image scaled to [0, 1].
    background = np.array(resized) / 255.0
    overlay = show_cam_on_image(background, cam_map, use_rgb=True)
    display_image(overlay, 'Grad-CAM')

    # Raw mask view with a colour scale.
    plt.imshow(cam_map, cmap='jet')
    plt.colorbar()
    plt.title("Grad-CAM Mask")
    plt.show()

    return cam_map


# LIME을 위한 batch_predict 함수
def batch_predict(model, images, transform, device):
    """Classifier function for LIME: map a batch of image arrays to class probabilities.

    Args:
        model: classifier returning logits of shape (batch, classes).
        images: sequence of RGB image arrays, as LIME passes them to its
            ``classifier_fn``. Each is turned into a PIL image and run
            through ``transform``.
        transform: the preprocessing pipeline used for the model input.
        device: device the stacked batch is moved to.

    Returns:
        NumPy array of softmax probabilities, one row per input image.
    """
    batch = torch.stack([transform(Image.fromarray(image)) for image in images], dim=0).to(device)
    # Inference only. LIME calls this once per perturbed batch (num_samples
    # images in total), so avoid building autograd graphs that are thrown away.
    with torch.no_grad():
        logits = model(batch)
    return torch.softmax(logits, dim=1).cpu().numpy()

# LIME 적용 함수
def apply_lime(model, image_path):
    """Explain the model's predicted class for one image with LIME and plot it.

    Two panels are shown:
    - the image with the boundaries of the (up to) 5 most positively
      contributing superpixels marked;
    - only those superpixels, on a white background.

    Args:
        model: classifier to explain; moved to the selected device, eval mode.
        image_path: image to explain; converted to RGB and resized to 256x256.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, transform = prepare_model_and_transform(model, device)

    img = Image.open(image_path).convert('RGB')
    img_resized = img.resize((256, 256))
    img_tensor = transform(img_resized).unsqueeze(0).to(device)

    # Pick the class to explain: the model's own top prediction. No gradients needed.
    with torch.no_grad():
        outputs = model(img_tensor)
    probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
    class_idx = np.argmax(probs)

    # A fixed random_state keeps LIME's sampled perturbations reproducible.
    explainer = lime_image.LimeImageExplainer(random_state=42)
    explanation = explainer.explain_instance(
        np.array(img_resized),
        lambda x: batch_predict(model, x, transform, device),
        labels=[class_idx],
        num_samples=1000,
    )

    temp, mask = explanation.get_image_and_mask(label=class_idx, positive_only=True, num_features=5, hide_rest=False)
    # temp is assumed to hold 0-255 pixel values (hence /255 for mark_boundaries).
    img_boundaries = mark_boundaries(temp / 255.0, mask)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img_boundaries)
    ax[0].axis('off')
    ax[0].set_title('LIME')

    # Copy only the selected superpixels (mask != 0) onto a white canvas.
    white_image = np.ones_like(temp) * 255
    only_important = np.copy(white_image)
    only_important[mask != 0] = temp[mask != 0]

    ax[1].imshow(only_important.astype(np.uint8))
    ax[1].axis('off')
    ax[1].set_title('Important Areas Only')
    plt.show()

# Insertion/Deletion 시각화 함수
def insertion_deletion(model, img_tensor, mask, mode='insertion', steps=20, class_idx=0):
    """Track a class probability while a saliency mask is inserted into / deleted from an image.

    At step ``i`` (``i = 0 .. steps``) the mask is scaled by ``alpha = i / steps``:

    * ``'insertion'``: ``img * (mask * alpha)``. Starts from an all-zero input
      (for a mean/std-normalized image, the per-channel mean colour) and
      gradually reveals the regions the mask marks as important. A faithful
      mask makes the score rise quickly.
    * ``'deletion'``: ``img * (1 - mask * alpha)``. Starts from the full image
      and gradually suppresses the important regions. A faithful mask makes
      the score drop quickly.

    Args:
        model: callable mapping a batch tensor to logits of shape (batch, classes).
        img_tensor: input batch; presumably (1, C, H, W) -- TODO confirm.
            Only ``[0]`` is kept for visualization.
        mask: saliency tensor broadcastable against ``img_tensor`` (the caller
            passes (1, 1, H, W)). Values presumably lie in [0, 1] -- TODO confirm.
        mode: ``'insertion'`` or ``'deletion'``.
        steps: number of increments (>= 1); ``steps + 1`` scores are produced.
        class_idx: class whose softmax probability is tracked.

    Returns:
        ``(scores, images)``: ``steps + 1`` probabilities and the matching
        modified images as channel-last NumPy arrays.

    Raises:
        ValueError: for an unknown ``mode`` or ``steps < 1``.
    """
    if mode not in ('insertion', 'deletion'):
        raise ValueError(f"mode must be 'insertion' or 'deletion', got {mode!r}")
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")

    device = img_tensor.device
    img_tensor = img_tensor.clone().detach().to(device)
    mask = mask.to(device)  # move once instead of on every step

    scores = []
    images = []
    for i in range(steps + 1):
        alpha = i / steps
        if mode == 'insertion':
            # Reveal the important regions progressively, starting from zeros.
            modified_img = img_tensor * (mask * alpha)
        else:
            # Remove the important regions progressively, starting from the full image.
            modified_img = img_tensor * (1 - mask * alpha)

        with torch.no_grad():
            score = torch.softmax(model(modified_img), dim=1)[:, class_idx].item()
        scores.append(score)
        images.append(modified_img[0].detach().cpu().permute(1, 2, 0).numpy())

    return scores, images

# 이미지 정규화 함수
def normalize_image(img):
    """Min-max rescale ``img`` into [0, 1] for display.

    An array whose minimum equals its maximum has no range to stretch. It is
    returned unchanged (the very same object, not a copy).
    """
    lo, hi = img.min(), img.max()
    if lo != hi:
        return (img - lo) / (hi - lo)
    return img  # constant array: nothing to rescale

# Insertion/Deletion 스텝별 시각화 함수
def visualize_insertion_deletion_steps(images, mode, steps=20):
    """Plot every intermediate image of an insertion/deletion run in a 2-row grid.

    Args:
        images: per-step images (channel-last arrays, as produced by
            ``insertion_deletion``); all must share one shape.
        mode: label used in each subplot title (capitalized).
        steps: number of steps of the run. It sets the minimum grid width
            (``steps // 2 + 1`` columns). The grid grows when ``images`` holds
            more frames than that, instead of raising IndexError.
    """
    n_cols = max(steps // 2 + 1, (len(images) + 1) // 2, 1)
    _, axs = plt.subplots(2, n_cols, figsize=(20, 5))
    axs = axs.flatten()

    # Rescale all frames with ONE shared min/max. Per-frame min-max would undo
    # any uniform scaling between steps (a * x rescales to the same picture as
    # x), hiding exactly the progression this figure is meant to show.
    frames = normalize_image(np.stack(images)) if len(images) else []

    for i, frame in enumerate(frames):
        axs[i].imshow(frame)
        axs[i].axis('off')
        axs[i].set_title(f"{mode.capitalize()} Step {i}")

    # Hide the grid cells left over after the last frame.
    for ax in axs[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Insertion 및 Deletion 수행 및 시각화 함수
def apply_insertion_deletion_visualize(model, image_path, cam_mask=None, steps=20, mode='gradcam'):
    """Compute and plot insertion/deletion curves for one image and one explanation.

    Args:
        model: classifier; moved to the selected device and set to eval mode.
        image_path: image to evaluate; converted to RGB and resized to 256x256.
        cam_mask: saliency map used when ``mode`` is not ``'lime'`` (e.g. the
            array returned by ``apply_grad_cam``). Presumably (256, 256) to
            match the resized input -- TODO confirm for other sources.
        steps: number of increments; each curve has ``steps + 1`` points.
        mode: ``'lime'`` builds the mask from a fresh LIME explanation (up to
            5 top positive superpixels). Any other value uses ``cam_mask``.

    Raises:
        ValueError: if ``mode`` is not ``'lime'`` and ``cam_mask`` is None.
    """
    # Fail early with a clear message instead of inside torch.tensor(None).
    if mode != 'lime' and cam_mask is None:
        raise ValueError(
            f"cam_mask is required when mode={mode!r}; pass the map returned by apply_grad_cam"
        )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, transform = prepare_model_and_transform(model, device)

    img = Image.open(image_path).convert('RGB')
    img_resized = img.resize((256, 256))
    img_tensor = transform(img_resized).unsqueeze(0).to(device)

    # Evaluate the explanation for the class the model itself predicts.
    with torch.no_grad():
        outputs = model(img_tensor)
    probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
    class_idx = np.argmax(probs)

    if mode == 'lime':
        explainer = lime_image.LimeImageExplainer(random_state=42)
        explanation = explainer.explain_instance(
            np.array(img_resized),
            lambda x: batch_predict(model, x, transform, device),
            labels=[class_idx],
            num_samples=1000,
        )
        _, saliency = explanation.get_image_and_mask(label=class_idx, positive_only=True, num_features=5, hide_rest=False)
    else:
        saliency = cam_mask  # Grad-CAM mask
    # Add batch and channel dims -> (1, 1, H, W) so it broadcasts over RGB channels.
    mask = torch.tensor(saliency).unsqueeze(0).unsqueeze(0).float().to(device)

    insertion_scores, insertion_images = insertion_deletion(model, img_tensor, mask, mode='insertion', steps=steps, class_idx=class_idx)
    deletion_scores, deletion_images = insertion_deletion(model, img_tensor, mask, mode='deletion', steps=steps, class_idx=class_idx)

    # x is the step fraction i / steps used to scale the mask.
    xs = np.linspace(0, 1, steps + 1)
    plt.plot(xs, insertion_scores, label='Insertion')
    plt.plot(xs, deletion_scores, label='Deletion')
    plt.xlabel('Percentage of Important Region')
    plt.ylabel('Model Output Score')
    plt.title(f'Insertion and Deletion Curves ({mode.capitalize()})')
    plt.legend()
    plt.show()

    # Per-step images for both runs.
    visualize_insertion_deletion_steps(insertion_images, mode='insertion', steps=steps)
    visualize_insertion_deletion_steps(deletion_images, mode='deletion', steps=steps)

# 실행 예시
model_name = 'densenet'
image_path = '/home/xai/son/MLRSNet/Images/parking_lot/parking_lot_00004.jpg'

# Load the checkpoint; the class count and learning rate are taken from config.yaml.
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

model = DenseNetModel.load_from_checkpoint(
    checkpoint_path='/home/xai/son/src/checkpoints/densnet201/densnet201-val_loss=0.10-val_f1=0.78.ckpt',
    num_classes=config['data']['num_classes'],
    learning_rate=config['train']['learning_rate'],
)

# 1) Grad-CAM: plot the heatmap and keep it as the saliency mask for the curves.
cam_mask = apply_grad_cam(model, image_path, model_name)

# 2) Insertion/deletion curves, first from the Grad-CAM mask, then from LIME.
for xai_mode, mask_arg in (('gradcam', cam_mask), ('lime', None)):
    apply_insertion_deletion_visualize(model, image_path, cam_mask=mask_arg, steps=10, mode=xai_mode)
