In [None]:
# Environment setup for the notebook kernel. The statements below run in this
# order on purpose: the import path and the matplotlib backend are adjusted
# before the heavier third-party packages are imported.

# Remove the per-user site-packages directory from the import search path.
import sys
user_site = '/home/xai/.local/lib/python3.10/site-packages'
if user_site in sys.path:
    sys.path.remove(user_site)

# Drop any MPLBACKEND override from the environment before importing matplotlib.
import os
os.environ.pop('MPLBACKEND', None)

import matplotlib
# NOTE(review): 'agg' is presumably a non-interactive backend, in which case
# the plt.show() calls further down would not display figures - confirm this
# is intended for this notebook.
matplotlib.use('agg')

import site
# NOTE(review): this flag is set after sys.path was already edited above;
# confirm it still has an effect at this point.
site.ENABLE_USER_SITE = False

import os
# Debug output: print the PYTHONNOUSERSITE environment variable (None if unset).
print(os.environ.get('PYTHONNOUSERSITE'))
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from lime import lime_image
from skimage.segmentation import mark_boundaries
import shap
import yaml
import pytorch_lightning as pl
from densenet201 import DenseNetModel  # updated import path
from torch import nn
from torchmetrics import Accuracy

# Helper that fixes the random seeds
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Also exports the value (as a string) through the ``PYTHONHASHSEED``
    environment variable, sets the cuDNN ``deterministic`` flag and clears
    the cuDNN ``benchmark`` flag.

    Args:
        seed: Value passed unchanged to every seeding call.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only touched when a CUDA device is available.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fix the seed for the whole run. The module-level ``seed`` is also read by
# apply_lime, which passes it to the LIME explainer as ``random_state``.
seed = 42
set_seed(seed)

# XAI helper functions
def get_target_layer(model, model_name):
    """Return the layer of ``model.model`` registered for ``model_name``.

    Args:
        model: Object exposing the backbone network as ``model.model``.
        model_name: One of ``'resnet'``, ``'googlenet'``, ``'efficientnet'``
            or ``'densenet'``.

    Returns:
        ``layer4[-1]`` for resnet, ``inception5b`` for googlenet and
        ``features[-1]`` for efficientnet and densenet.

    Raises:
        ValueError: If ``model_name`` is none of the supported names.
    """
    # Ordered (name, selector) pairs; selectors run only for the matching name.
    layer_selectors = (
        ('resnet', lambda backbone: backbone.layer4[-1]),
        ('googlenet', lambda backbone: backbone.inception5b),
        ('efficientnet', lambda backbone: backbone.features[-1]),
        ('densenet', lambda backbone: backbone.features[-1]),
    )
    for known_name, select_layer in layer_selectors:
        if model_name == known_name:
            return select_layer(model.model)
    raise ValueError(f"Unsupported model name: {model_name}")

def display_image(image, title):
    """Draw ``image`` with pyplot, titled and with the axes hidden, then show it.

    Args:
        image: Image data handed to ``plt.imshow``.
        title: Text placed above the image.
    """
    plt.imshow(image)
    plt.title(title)
    plt.axis('off')
    plt.show()

def apply_grad_cam(model, image_path, model_name):
    """Compute a Grad-CAM heat map for one image and display it as an overlay.

    The model is put in eval mode and moved to CUDA when available (CPU
    otherwise). The image is loaded as RGB, resized to 256x256 and normalised
    before being passed to Grad-CAM; the resulting map is blended over the
    resized image and shown with ``display_image``.

    Args:
        model: Model wrapper whose backbone layers are resolved by
            ``get_target_layer``.
        image_path: Path of the image file to explain.
        model_name: Architecture key forwarded to ``get_target_layer``.

    Raises:
        ValueError: Propagated from ``get_target_layer`` for an unsupported
            ``model_name``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    model.to(device)

    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # adjust to match the dataset
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Force three channels: grayscale, palette or RGBA files would not match
    # the 3-channel Normalize above. The context manager closes the file.
    with Image.open(image_path) as img:
        img_resized = img.convert('RGB').resize((256, 256))
    img_tensor = transform(img_resized).unsqueeze(0).to(device)

    target_layer = get_target_layer(model, model_name)
    # ``use_cuda`` is intentionally not passed: recent pytorch-grad-cam
    # releases no longer accept it, and the model and input already share a
    # device. Using the CAM object as a context manager removes its hooks
    # from the model once the map has been computed.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=img_tensor)[0, :]

    # show_cam_on_image is given the image scaled to the [0, 1] range.
    rgb_image = np.array(img_resized) / 255.0
    visualization = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)

    display_image(visualization, 'Grad-CAM')

def apply_lime(model, image_path):
    """Explain the model's score for class 0 on one image with LIME and plot it.

    Two panels are drawn: the image with the boundaries of the selected
    superpixels marked, and the same image with everything outside those
    superpixels painted white.

    The model is put in eval mode and moved to CUDA when available (CPU
    otherwise). The module-level ``seed`` is used as the explainer's
    ``random_state``.

    Args:
        model: Callable model returning one logit per class for a batch of
            normalised image tensors; scores are obtained with a sigmoid.
        image_path: Path of the image file to explain.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    model.to(device)

    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # adjust to match the dataset
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Force three channels so the Normalize above always matches; the
    # context manager closes the file.
    with Image.open(image_path) as img:
        img_resized = img.convert('RGB').resize((256, 256))

    # Index of the class of interest (e.g. class 0)
    class_idx = 0

    def batch_predict(images):
        """LIME classifier callback: image arrays in, per-class scores out."""
        model.eval()
        batch = torch.stack([transform(Image.fromarray(image)) for image in images], dim=0)
        batch = batch.to(device)
        # Inference only: skip building autograd graphs for the perturbed samples.
        with torch.no_grad():
            logits = model(batch)
        return torch.sigmoid(logits).cpu().numpy()

    explainer = lime_image.LimeImageExplainer(random_state=seed)
    explanation = explainer.explain_instance(
        np.array(img_resized),
        batch_predict,
        labels=(class_idx,),
        hide_color=0,
        # The default top_labels=5 overrides ``labels`` and explains only the
        # five highest-scoring classes, so class_idx could be missing from the
        # explanation. Disable it so exactly ``labels`` is explained.
        top_labels=None,
        num_samples=1000
    )

    temp, mask = explanation.get_image_and_mask(
        label=class_idx,
        positive_only=True,
        num_features=5,
        hide_rest=False
    )
    img_boundry = mark_boundaries(temp / 255.0, mask)

    # Keep only the important regions and paint the rest white. The mask is
    # an integer 0/1 array, so it must be converted to boolean first; indexing
    # with the raw integers would select rows 0 and 1 instead of masking.
    important = mask.astype(bool)
    only_important = np.full_like(temp, 255)
    only_important[important] = temp[important]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img_boundry)
    ax[0].axis('off')
    ax[0].set_title('LIME')

    ax[1].imshow(only_important)
    ax[1].axis('off')
    ax[1].set_title('Important Areas Only')

    plt.show()

def apply_gradient_shap(model, image_path):
    """Compute gradient-based SHAP values for class 0 on one image and plot them.

    The model is put in eval mode and moved to CUDA when available (CPU
    otherwise). The background set consists of an all-zero tensor and the
    normalised image itself.

    Args:
        model: ``torch.nn.Module`` whose output is indexed as
            ``output[:, class_idx]`` (assumed to be ``(batch, num_classes)``
            - confirm against the model definition).
        image_path: Path of the image file to explain.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    model.to(device)

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Force three channels so the Normalize above always matches; the
    # context manager closes the file.
    with Image.open(image_path) as img:
        img_resized = img.convert('RGB').resize((256, 256))
    img_tensor = transform(img_resized).unsqueeze(0).to(device)

    # Index of the class of interest (e.g. class 0)
    class_idx = 0

    class _SingleClassOutput(torch.nn.Module):
        """Expose only column ``class_idx`` of the wrapped model's output."""

        def __init__(self, wrapped):
            super().__init__()
            self.wrapped = wrapped

        def forward(self, x):
            # Slice (not index) so the output keeps its class dimension.
            return self.wrapped(x)[:, class_idx:class_idx + 1]

    # Build the SHAP explainer. GradientExplainer.shap_values has no
    # ``targets`` argument, so the class is selected by wrapping the model
    # instead of passing a keyword the API would reject.
    background = torch.cat([img_tensor * 0, img_tensor * 1], dim=0)  # background data
    e = shap.GradientExplainer(_SingleClassOutput(model).eval(), background)

    # Compute the SHAP values
    shap_values = e.shap_values(img_tensor, nsamples=200)

    # Depending on the installed shap version the result may be a list with
    # one array per output or an array with a trailing output axis; reduce
    # both to a single (batch, channels, height, width) array.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 5:
        shap_values = shap_values[..., 0]

    # image_plot works on channels-last batches, so move the channel axis and
    # give the displayed image a batch axis, scaled to [0, 1].
    shap_hwc = np.transpose(shap_values, (0, 2, 3, 1))
    pixels = np.asarray(img_resized, dtype=np.float32)[np.newaxis] / 255.0

    # Visualise the SHAP values
    shap.image_plot([shap_hwc], pixels)

def apply_all_xai_methods(model, image_path, model_name):
    """Run Grad-CAM, LIME and Gradient SHAP on one image, in that order.

    A banner line is printed before each method starts.

    Args:
        model: Model passed to each explanation function.
        image_path: Path of the image file to explain.
        model_name: Architecture key, used only by the Grad-CAM step.
    """
    stages = (
        ("Applying Grad-CAM", lambda: apply_grad_cam(model, image_path, model_name)),
        ("Applying LIME", lambda: apply_lime(model, image_path)),
        ("Applying Gradient SHAP", lambda: apply_gradient_shap(model, image_path)),
    )
    for banner, run_stage in stages:
        print(banner)
        run_stage()

# Example usage
model_name = 'densenet'
image_path = '/home/xai/son/MLRSNet/Images/bareland/bareland_00002.jpg'

# Load the settings from config.yaml (resolved against the current working
# directory). The 'data.num_classes' and 'train.learning_rate' entries are
# read below.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Load the trained model
# NOTE(review): the checkpoint file name starts with "resnet_" while the class
# being loaded is DenseNetModel and model_name is 'densenet' - confirm this is
# the intended checkpoint.
model = DenseNetModel.load_from_checkpoint(
    checkpoint_path='/home/xai/son/src/checkpoints/resnet_final_epoch_20241022_233321.pth',
    num_classes=config['data']['num_classes'],
    learning_rate=config['train']['learning_rate']
)

# Run Grad-CAM, LIME and Gradient SHAP on the example image.
apply_all_xai_methods(model, image_path, model_name)
