In [None]:
import sys
import os
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from lime import lime_image
from skimage.segmentation import mark_boundaries
import yaml
import pytorch_lightning as pl
from densenet201 import DenseNetModel
from numpy import trapz
%matplotlib inline

# Fix the random seeds used throughout this notebook
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Seed value handed to every generator; it is also written to the
            ``PYTHONHASHSEED`` environment variable as a string.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same order as before: Python, then NumPy, then PyTorch (CPU)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)  # Apply one shared seed to every RNG
# XAI helper functions
def get_target_layer(model, model_name):
    """Pick the layer of ``model.model`` to use as the CAM target layer.

    Args:
        model: Wrapper object exposing the backbone as ``model.model``.
        model_name: One of 'resnet', 'googlenet', 'efficientnet', 'densenet'.

    Returns:
        The selected attribute/element of the backbone.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'resnet':
        return model.model.layer4[-1]
    if model_name == 'googlenet':
        return model.model.inception5b
    if model_name in ('efficientnet', 'densenet'):
        # Both names resolve to the final entry of `features`
        return model.model.features[-1]
    raise ValueError(f"Unsupported model name: {model_name}")
# Model/device setup and image transform
def prepare_model_and_transform(model, device='cuda'):
    """Move ``model`` to ``device``, switch it to eval mode and build the input transform.

    Args:
        model: Object providing ``to(device)`` and ``eval()``.
        device: Target device passed to ``model.to``.

    Returns:
        Tuple ``(model, transform)`` where ``transform`` resizes to 256x256,
        converts to a tensor and normalises each channel.
    """
    model.to(device)
    model.eval()

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return model, transforms.Compose(pipeline)

# Image display helper
def display_image(image, title=''):
    """Show ``image`` with matplotlib, axes hidden.

    Args:
        image: Image data accepted by ``plt.imshow``.
        title: Optional figure title; skipped when empty.
    """
    plt.imshow(image)
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

# Run Grad-CAM and inspect both the overlay and the raw mask
def apply_grad_cam(model, image_path, model_name):
    """Compute a Grad-CAM map for one image, plot it, and return the mask.

    Args:
        model: Model wrapper; prepared via ``prepare_model_and_transform``.
        image_path: Path of the image file to explain.
        model_name: Architecture key understood by ``get_target_layer``.

    Returns:
        The first entry of the CAM output for the single-image batch
        (``cam(...)[0, :]``).
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model, transform = prepare_model_and_transform(model, device)

    # The same 256x256 RGB image feeds both the network and the overlay
    resized = Image.open(image_path).convert('RGB').resize((256, 256))
    batch = transform(resized).unsqueeze(0).to(device)

    layers = [get_target_layer(model, model_name)]
    explainer = GradCAM(model=model, target_layers=layers)
    heatmap = explainer(input_tensor=batch)[0, :]

    # Overlay the heatmap on the image scaled to [0, 1]
    rgb_unit = np.array(resized) / 255.0
    overlay = show_cam_on_image(rgb_unit, heatmap, use_rgb=True)
    display_image(overlay, 'Grad-CAM')

    # Plot the raw CAM mask on its own
    plt.imshow(heatmap, cmap='jet')
    plt.colorbar()
    plt.title("Grad-CAM Mask")
    plt.show()

    return heatmap


# Classifier function used by LIME
def batch_predict(model, images, transform, device):
    """Return softmax class probabilities for a batch of images.

    Args:
        model: Callable producing logits for a stacked image batch.
        images: Iterable of arrays accepted by ``Image.fromarray``.
        transform: Callable turning a PIL image into a tensor.
        device: Device the stacked batch is moved to before inference.

    Returns:
        NumPy array of probabilities (softmax over dimension 1), one row per
        input image.
    """
    batch = torch.stack(
        [transform(Image.fromarray(image)) for image in images], dim=0
    ).to(device)
    # Inference only: LIME calls this once per perturbation batch, so avoid
    # building an autograd graph that is thrown away immediately.
    with torch.no_grad():
        logits = model(batch)
    return torch.softmax(logits, dim=1).detach().cpu().numpy()

# Apply LIME to a single image
def apply_lime(model, image_path):
    """Explain the model's top-scoring class for one image with LIME and plot it.

    Two panels are shown: the image with the selected superpixel boundaries
    marked, and the image with everything outside the mask painted white.

    Args:
        model: Model wrapper; prepared via ``prepare_model_and_transform``.
        image_path: Path of the image file to explain.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, transform = prepare_model_and_transform(model, device)

    resized = Image.open(image_path).convert('RGB').resize((256, 256))
    batch = transform(resized).unsqueeze(0).to(device)

    # The explained label is the arg-max of the model's own prediction
    probabilities = torch.softmax(model(batch), dim=1).detach().cpu().numpy()[0]
    class_idx = np.argmax(probabilities)

    def classifier_fn(samples):
        return batch_predict(model, samples, transform, device)

    explanation = lime_image.LimeImageExplainer(random_state=42).explain_instance(
        np.array(resized),
        classifier_fn,
        labels=[class_idx],
        num_samples=1000,
    )
    temp, mask = explanation.get_image_and_mask(
        label=class_idx,
        positive_only=True,
        num_features=5,
        hide_rest=False,
    )
    outlined = mark_boundaries(temp / 255.0, mask)

    _, (lime_ax, important_ax) = plt.subplots(1, 2, figsize=(10, 5))

    lime_ax.imshow(outlined)
    lime_ax.axis('off')
    lime_ax.set_title('LIME')

    # Start from an all-255 canvas and copy back only the masked pixels
    selected = mask != 0
    only_important = np.ones_like(temp) * 255
    only_important[selected] = temp[selected]

    important_ax.imshow(only_important.astype(np.uint8))
    important_ax.axis('off')
    important_ax.set_title('Important Areas Only')
    plt.show()

def insertion_deletion(model, img_tensor, mask, mode='insertion', steps=20, class_idx=0):
    """Score the model while an importance mask progressively scales the image.

    At step ``i`` (``0 .. steps``) the image is multiplied element-wise by
    ``mask * i / steps`` (insertion) or by ``1 - mask * i / steps``
    (deletion), and the softmax probability of ``class_idx`` is recorded and
    printed. The scaling is applied to ``img_tensor`` exactly as given, so
    suppressed pixels move towards 0 in whatever value space the tensor uses.

    Args:
        model: Callable returning logits; softmax is taken over dimension 1.
        img_tensor: Single-image batch tensor that squeezes to (C, H, W).
        mask: Tensor that squeezes to (H, W); it is repeated over 3 channels.
        mode: ``'insertion'`` or ``'deletion'``.
        steps: Number of increments; ``steps + 1`` images are evaluated.
        class_idx: Index of the class whose probability is tracked.

    Returns:
        Tuple ``(scores, images)``: a list of float probabilities and the list
        of (H, W, C) NumPy images that produced them.

    Raises:
        ValueError: If ``mode`` is not supported or ``steps`` is less than 1.
    """
    # Fail early with a clear message instead of an unbound-variable error
    # (unknown mode) or a bare ZeroDivisionError (steps == 0) mid-loop.
    if mode not in ('insertion', 'deletion'):
        raise ValueError(f"Unsupported mode: {mode!r} (expected 'insertion' or 'deletion')")
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    device = img_tensor.device
    step = 1.0 / steps

    # Channel-last NumPy view of the input; it is only read, never modified
    img_original = img_tensor.detach().cpu().squeeze().permute(1, 2, 0).numpy()

    # Expand the (H, W) mask to 3 channels and keep it in the image's dtype so
    # a float64 mask cannot silently promote the frames fed to the model.
    mask_2d = mask.detach().squeeze().cpu().numpy()
    mask_3d = np.repeat(mask_2d[:, :, np.newaxis], 3, axis=2).astype(img_original.dtype, copy=False)

    scores = []
    images = []
    for i in range(steps + 1):
        if mode == 'insertion':
            # Insertion: scale the image by a growing fraction of the mask
            modified_img = img_original * (mask_3d * (step * i))
        else:
            # Deletion: scale the image by one minus a growing fraction of the mask
            modified_img = img_original * (1 - mask_3d * (step * i))

        with torch.no_grad():
            modified_tensor = (
                torch.tensor(modified_img)
                .unsqueeze(0)
                .permute(0, 3, 1, 2)
                .to(device=device, dtype=img_tensor.dtype)
            )
            output = torch.softmax(model(modified_tensor), dim=1)[:, class_idx].cpu().item()

        # Per-step progress output
        print(f'Step {i} - {mode.capitalize()} score: {output}')

        scores.append(output)
        images.append(modified_img)

    return scores, images

# Image normalisation helper
def normalize_image(img):
    """Min-max scale ``img`` to the range [0, 1].

    Args:
        img: Array-like object providing ``min()`` and ``max()``.

    Returns:
        The rescaled image, or ``img`` itself (unchanged) when all values are
        equal and no scaling is possible.
    """
    lowest = img.min()
    highest = img.max()
    if lowest != highest:
        return (img - lowest) / (highest - lowest)
    return img

# Per-step visualisation of the insertion/deletion images
def visualize_insertion_deletion_steps(images, mode, steps=20):
    """Show the image of every step in a two-row grid.

    Each image is min-max normalised with ``normalize_image`` before display.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        mode: Label used in the subplot titles (capitalised).
        steps: Step count used to size the grid: at least ``(steps // 2) + 1``
            columns. The grid is widened when ``images`` holds more frames
            than that, so a mismatched ``steps`` cannot index past the axes.
    """
    n_cols = max((steps // 2) + 1, (len(images) + 1) // 2, 1)
    _, axs = plt.subplots(2, n_cols, figsize=(20, 5))
    axs = axs.flatten()

    for i, img in enumerate(images):
        axs[i].imshow(normalize_image(img))
        axs[i].axis('off')
        axs[i].set_title(f"{mode.capitalize()} Step {i}")

    # Hide the axes of any unused subplots
    for ax in axs[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Plot the resulting score curves
def plot_insertion_deletion_curves(insertion_scores, deletion_scores, steps, mode):
    """Plot insertion and deletion scores against the step index.

    Args:
        insertion_scores: Scores for steps ``0 .. steps`` in insertion mode.
        deletion_scores: Scores for steps ``0 .. steps`` in deletion mode.
        steps: Number of steps; the x axis is ``range(steps + 1)``.
        mode: Name shown (capitalised) in the plot title.
    """
    x_steps = range(steps + 1)
    curves = (
        (insertion_scores, 'Insertion'),
        (deletion_scores, 'Deletion'),
    )
    for curve, label in curves:
        plt.plot(x_steps, curve, label=label, marker='o')

    plt.xlabel('Step')
    plt.ylabel('Model Output Score')
    plt.ylim(0, 1)
    plt.title(f'Insertion and Deletion Curves ({mode.capitalize()})')
    plt.legend()
    plt.show()

# AUPC computation
def calculate_aupc(scores, steps):
    """Integrate ``scores`` over [0, 1] with the trapezoidal rule.

    Args:
        scores: Sequence of ``steps + 1`` score values.
        steps: Number of steps; the x axis is ``steps + 1`` evenly spaced
            points from 0 to 1.

    Returns:
        The area under the score curve.
    """
    x = np.linspace(0, 1, steps + 1)
    # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`;
    # prefer the new name and fall back for older NumPy versions.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return integrate(scores, x)

# Run insertion and deletion, compute AUPC, and visualise the results
def apply_insertion_deletion_visualize(model, image_path, cam_mask=None, steps=20, mode='gradcam', class_labels=None):
    """Run the insertion/deletion evaluation for one image and plot the results.

    The predicted class (arg-max of the model output) is evaluated. The
    importance mask is computed with LIME when ``mode == 'lime'``; for any
    other mode the caller-supplied ``cam_mask`` is used. Scores, the mean of
    the insertion and deletion areas, the score curves and the per-step images
    are printed/plotted.

    Args:
        model: Model wrapper; prepared via ``prepare_model_and_transform``.
        image_path: Path of the image file to evaluate.
        cam_mask: Precomputed importance mask; required unless ``mode='lime'``.
        steps: Number of insertion/deletion steps.
        mode: ``'lime'`` to compute a LIME mask, anything else to use ``cam_mask``.
        class_labels: Optional sequence of class names indexed by class index.

    Raises:
        ValueError: If ``cam_mask`` is missing for a non-LIME mode.
    """
    # Fail fast: without this check the missing mask only surfaces later as an
    # obscure error from `torch.tensor(None)`.
    if mode != 'lime' and cam_mask is None:
        raise ValueError(
            f"cam_mask is required when mode={mode!r}; only mode='lime' computes its own mask"
        )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, transform = prepare_model_and_transform(model, device)

    img = Image.open(image_path).convert('RGB')
    img_resized = img.resize((256, 256))
    img_tensor = transform(img_resized).unsqueeze(0).to(device)

    # Prediction only: no autograd graph is needed here
    with torch.no_grad():
        outputs = model(img_tensor)
    probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()[0]
    class_idx = np.argmax(probs)

    if class_labels is not None:
        class_name = class_labels[class_idx]
        print(f"Predicted class: {class_name} (class_idx: {class_idx})")
    else:
        print(f"Predicted class_idx: {class_idx}")

    if mode == 'lime':
        explainer = lime_image.LimeImageExplainer(random_state=42)
        explanation = explainer.explain_instance(np.array(img_resized), lambda x: batch_predict(model, x, transform, device), labels=[class_idx], num_samples=1000)
        _, mask = explanation.get_image_and_mask(label=class_idx, positive_only=True, num_features=5, hide_rest=False)
        mask = torch.tensor(mask).unsqueeze(0).unsqueeze(0).float().to(device)
    else:
        mask = torch.tensor(cam_mask).unsqueeze(0).unsqueeze(0).float().to(device)

    # Insertion
    print(f'=== Insertion ({mode.capitalize()}) ===')
    insertion_scores, insertion_images = insertion_deletion(model, img_tensor, mask, mode='insertion', steps=steps, class_idx=class_idx)

    # Deletion
    print(f'=== Deletion ({mode.capitalize()}) ===')
    deletion_scores, deletion_images = insertion_deletion(model, img_tensor, mask, mode='deletion', steps=steps, class_idx=class_idx)

    # AUPC: area under each curve, plus their mean
    insertion_aupc = calculate_aupc(insertion_scores, steps)
    deletion_aupc = calculate_aupc(deletion_scores, steps)
    total_aupc = (insertion_aupc + deletion_aupc) / 2

    print(f'AUPC: {total_aupc:.3f} (insertion: {insertion_aupc:.3f}, deletion: {deletion_aupc:.3f})')

    # Plot the score curves
    plot_insertion_deletion_curves(insertion_scores, deletion_scores, steps, mode)

    # Show the image of every step
    visualize_insertion_deletion_steps(insertion_images, mode='insertion', steps=steps)
    visualize_insertion_deletion_steps(deletion_images, mode='deletion', steps=steps)

# Example run
model_name = 'densenet'
image_path = '/home/xai/son/MLRSNet/Images/basketball_court/basketball_court_01327.jpg'

# Class names, looked up by predicted class index (class_labels[class_idx]).
# NOTE(review): 'habor' looks like a misspelling of 'harbor'; it is a runtime
# string, so confirm against the dataset's label names before changing it.
class_labels = ['airplane','airport','bare soil','baseball diamond','basketball court','beach','bridge','buildings','cars','chaparral','cloud','containers','crosswalk','dense residential area','desert','dock','factory','field','football field','forest','freeway','golf course','grass','greenhouse','gully','habor','intersection','island','lake','mobile home','mountain','overpass','park','parking lot','parkway','pavement','railway','railway station','river','road','roundabout','runway','sand','sea','ships','snow','snowberg','sparse residential area','stadium','swimming pool','tanks','tennis court','terrace','track','trail','transmission tower','trees','water','wetland','wind turbine']

# Load the model: hyper-parameters come from config.yaml, weights from the checkpoint
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

model = DenseNetModel.load_from_checkpoint(
    checkpoint_path='/home/xai/son/src/checkpoints/densenet_final_epoch_20241023_204532.ckpt',
    num_classes=config['data']['num_classes'],
    learning_rate=config['train']['learning_rate']
)
# Run Grad-CAM; the returned mask is reused for the Grad-CAM evaluation below
cam_mask = apply_grad_cam(model, image_path, model_name)

# Number of insertion/deletion steps used for both evaluations
step = 10
# Insertion/deletion with visualisation (Grad-CAM mask, class names printed)
apply_insertion_deletion_visualize(model, image_path, cam_mask=cam_mask, steps=step, mode='gradcam', class_labels=class_labels)

# Insertion/deletion with visualisation (LIME mask computed inside the call, class names printed)
apply_insertion_deletion_visualize(model, image_path, steps=step, mode='lime', class_labels=class_labels)
