<a href="https://colab.research.google.com/github/ykitaguchi77/CorneAI/blob/main/YOLOv5-LIME-RISE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **YOLOv5 LIME CorneAI**

## **Setup YOLOv5**

In [203]:
# Mount Google Drive at /gdrive; the model weights and images used by later
# cells are addressed through /gdrive/MyDrive/... paths.
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [204]:
%cd /content
!pip uninstall deep_utils -y
!pip install -U git+https://github.com/pooya-mohammadi/deep_utils.git --q
!pip install torch --q
!pip install torchvision --q
!pip install -U opencv-python --q
print("[INFO] To use new installed version of opencv, the session should be restarted!!!!")

!git clone https://github.com/pooya-mohammadi/yolov5-gradcam

/content


NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968

In [None]:
import os
# Work from inside the cloned repository; the later `models.*` / `utils.*`
# imports are presumably resolved relative to this directory.
os.chdir('/content/yolov5-gradcam')

# Trained weights and a sample image on the mounted Drive.
# NOTE(review): the LIME usage example further down assigns model_path and
# img_path again (pointing at a different image), overriding these values.
model_path = "/gdrive/MyDrive/Deep_learning/CorneAI_nagoya/yolo5_forcresco/weights/eye_nii_2202_onecaseoneimage2_doctorcompare_yolov5s_epoch200_batch16_89.8p/last.pt"
img_path = "/gdrive/MyDrive/研究/進行中の研究/角膜スマートフォンAIプロジェクト/前原の240問/フォトスリット_serial/3.jpg"

# **LIME**

In [None]:
# Extra dependencies of the LIME section: `lime` provides lime_image and
# scikit-image provides skimage.segmentation, both imported in the next cell.
!pip install lime --q
!pip install scikit-image --q


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/275.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for lime (setup.py) ... [?25l[?25hdone


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
import cv2
import time
import traceback
import torchvision
from lime import lime_image
from deep_utils.utils.box_utils.boxes import Box
from models.experimental import attempt_load
from utils.general import xywh2xyxy
from utils.datasets import letterbox
from utils.metrics import box_iou
%matplotlib inline

# Class names of the cornea AI model.
# The position in this list is used as the class index: the detector looks up
# names[cls_idx] and the LIME predict function writes class_scores[cls_idx].
CORNEA_CLASSES = [
    "infection",
    "normal",
    "non-infection",
    "scar",
    "tumor",
    "deposit",
    "APAC",
    "lens opacity",
    "bullous"
]

def setup_device():
    """
    Select the GPU when CUDA is available, otherwise fall back to the CPU.

    The chosen device is reported on stdout; for a GPU the device name and its
    total memory (in GB) are printed as well.

    Returns:
    --------
    device : torch.device
        The device to run on.
    """
    # Guard clause: nothing to inspect when CUDA is absent.
    if not torch.cuda.is_available():
        cpu = torch.device('cpu')
        print("GPU not available, using CPU")
        return cpu

    gpu = torch.device('cuda')
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    return gpu

class YOLOV5TorchObjectDetector(nn.Module):
    """YOLOv5 wrapper that returns post-NMS detections together with their class logits.

    The weights are loaded with ``attempt_load``; ``forward`` unpacks the model
    output as ``(prediction, logits, _)``, so the loaded model must return that
    triple.
    """

    def __init__(self,
                 model_weight,
                 device,
                 img_size=(640, 640),
                 names=CORNEA_CLASSES,
                 mode='eval',
                 confidence=0.25,
                 iou_thresh=0.45,
                 agnostic_nms=False):
        """Load the model, store the thresholds and run one warm-up inference.

        Args:
            model_weight: Weights path handed to ``attempt_load``.
            device: Device the model and the input tensors are placed on.
            img_size: ``(height, width)`` used for the warm-up tensor, as the
                letterbox target and for clamping boxes.
            names: Class names indexed by class id.
            mode: ``'train'`` puts the model in training mode, anything else in
                eval mode.
            confidence: Confidence threshold passed to NMS.
            iou_thresh: IoU threshold passed to NMS.
            agnostic_nms: Passed to NMS as ``agnostic``.
        """
        super(YOLOV5TorchObjectDetector, self).__init__()
        self.device = device
        self.model = None
        self.img_size = img_size
        self.mode = mode
        self.confidence = confidence
        self.iou_thresh = iou_thresh
        self.agnostic = agnostic_nms

        # Load the model
        print("[INFO] Loading cornea detection model...")
        self.model = attempt_load(model_weight, device=device)
        print("[INFO] Model loaded successfully")

        # Read the number of classes from the model
        self.nc = int(self.model.nc)
        print(f"[INFO] Number of classes: {self.nc}")

        # Store the class names and warn when they do not match the model
        self.names = names
        if len(self.names) != self.nc:
            print(f"[WARNING] Number of class names ({len(self.names)}) does not match model classes ({self.nc})")
        print(f"[INFO] Using class names: {self.names}")

        # Parameters keep requires_grad=True even though inference is the default mode.
        self.model.requires_grad_(True)
        self.model.to(device)
        if self.mode == 'train':
            self.model.train()
        else:
            self.model.eval()

        # Cold start prevention: one dummy forward pass at the configured size
        print("[INFO] Performing cold start prevention...")
        img = torch.zeros((1, 3, *self.img_size), device=device)
        self.model(img)
        print("[INFO] Initialization complete")

    @staticmethod
    def non_max_suppression(prediction, logits, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False,
                            multi_label=False, labels=(), max_det=300):
        """Run NMS on the predictions and prune the logits with the same row selection.

        Args:
            prediction: Tensor indexed as ``[image, candidate, 5 + nc]`` with
                xywh in columns 0-3, objectness in column 4 and class scores after.
            logits: Per-candidate logits, aligned row by row with ``prediction``.
            conf_thres: Minimum objectness and minimum final confidence.
            iou_thres: IoU threshold for ``torchvision.ops.nms``.
            classes: Optional class ids to keep.
            agnostic: When true, boxes of different classes suppress each other.
            multi_label: Allow several labels per box (only when nc > 1).
            labels: Accepted for signature compatibility; not used here.
            max_det: Maximum number of detections kept per image.

        Returns:
            ``(output, logits_output)``: per image an ``(n, 6)`` tensor
            ``[x1, y1, x2, y2, conf, cls]`` and the ``n`` matching logits rows.
        """
        nc = prediction.shape[2] - 5
        xc = prediction[..., 4] > conf_thres

        # Checks
        assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
        assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

        # Settings
        max_wh = 4096      # class offset used to keep classes apart in batched NMS
        max_nms = 30000    # maximum number of boxes handed to torchvision.ops.nms()
        time_limit = 10.0  # seconds after which the remaining images are skipped
        multi_label &= nc > 1

        t = time.time()
        output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
        logits_output = [torch.zeros((0, nc), device=logits.device)] * logits.shape[0]

        for xi, (x, log_) in enumerate(zip(prediction, logits)):
            # Keep only candidates above the objectness threshold
            x = x[xc[xi]]
            log_ = log_[xc[xi]]

            if not x.shape[0]:
                continue

            # conf = obj_conf * cls_conf
            x[:, 5:] *= x[:, 4:5]
            box = xywh2xyxy(x[:, :4])

            if multi_label:
                i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
                x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
                # One logits row per emitted (box, label) pair
                log_ = log_[i]
            else:
                conf, j = x[:, 5:].max(1, keepdim=True)
                keep = conf.view(-1) > conf_thres
                x = torch.cat((box, conf, j.float()), 1)[keep]
                log_ = log_[keep]

            if classes is not None:
                # Filter the logits with the same mask so rows stay aligned
                keep = (x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)
                x, log_ = x[keep], log_[keep]

            n = x.shape[0]
            if not n:
                continue
            elif n > max_nms:
                # Truncate the logits with the same ordering as the boxes
                order = x[:, 4].argsort(descending=True)[:max_nms]
                x, log_ = x[order], log_[order]

            # Offset boxes by class so NMS is per class unless agnostic
            c = x[:, 5:6] * (0 if agnostic else max_wh)
            boxes, scores = x[:, :4] + c, x[:, 4]
            i = torchvision.ops.nms(boxes, scores, iou_thres)

            if i.shape[0] > max_det:
                i = i[:max_det]

            output[xi] = x[i]
            logits_output[xi] = log_[i]

            if (time.time() - t) > time_limit:
                print(f'WARNING: NMS time limit {time_limit}s exceeded')
                break

        return output, logits_output

    @staticmethod
    def yolo_resize(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
        """Thin wrapper around ``letterbox``; returns whatever ``letterbox`` returns."""
        return letterbox(img, new_shape=new_shape, color=color, auto=auto, scaleFill=scaleFill, scaleup=scaleup)

    def forward(self, img):
        """Detect objects in a preprocessed batch.

        Args:
            img: Batch tensor as produced by ``preprocessing``.

        Returns:
            ``([boxes, classes, class_names, confidences], logits)`` where each
            of the four lists holds one list per image. On any exception the
            error is printed and empty per-image lists plus ``None`` are returned.
        """
        try:
            prediction, logits, _ = self.model(img, augment=False)
            prediction, logits = self.non_max_suppression(prediction, logits, self.confidence, self.iou_thresh,
                                                          classes=None,
                                                          agnostic=self.agnostic)

            batch_size = img.shape[0]
            self.boxes = [[] for _ in range(batch_size)]
            self.class_names = [[] for _ in range(batch_size)]
            self.classes = [[] for _ in range(batch_size)]
            self.confidences = [[] for _ in range(batch_size)]

            for i, det in enumerate(prediction):
                if len(det):
                    for *xyxy, conf, cls in det:
                        # Clamp the box to the configured (height, width)
                        xyxy[0] = max(0, xyxy[0])
                        xyxy[1] = max(0, xyxy[1])
                        xyxy[2] = min(self.img_size[1], xyxy[2])
                        xyxy[3] = min(self.img_size[0], xyxy[3])

                        bbox = Box.box2box(xyxy,
                                           in_source=Box.BoxSource.Torch,
                                           to_source=Box.BoxSource.Numpy,
                                           return_int=True)

                        self.boxes[i].append(bbox)
                        self.confidences[i].append(float(conf.item()))
                        cls_idx = int(cls.item())

                        # An index outside the name list is reported and mapped to class 0
                        if cls_idx >= len(self.names):
                            print(f"[WARNING] Class index {cls_idx} is out of range")
                            cls_idx = 0

                        self.classes[i].append(cls_idx)
                        self.class_names[i].append(self.names[cls_idx])

            return [self.boxes, self.classes, self.class_names, self.confidences], logits

        except Exception as e:
            print(f"Error in forward pass: {e}")
            traceback.print_exc()
            # Keep the "one list per image" shape of the normal return value
            n_img = img.shape[0] if hasattr(img, "shape") and len(img.shape) else 1
            return [[[] for _ in range(n_img)] for _ in range(4)], None

    def preprocessing(self, img):
        """Turn image array(s) into a normalised batch tensor on ``self.device``.

        Args:
            img: Array indexed ``[H, W, C]`` or ``[N, H, W, C]``; it is cast to uint8.

        Returns:
            The letterboxed batch as a tensor in channel-first layout scaled by
            1/255, or ``None`` (after printing the error) when anything fails.
        """
        try:
            if len(img.shape) != 4:
                img = np.expand_dims(img, axis=0)
            im0 = img.astype(np.uint8)
            img = np.array([self.yolo_resize(im, new_shape=self.img_size)[0] for im in im0])
            img = img.transpose((0, 3, 1, 2))
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(self.device)
            img = img / 255.0
            return img
        except Exception as e:
            print(f"Error in preprocessing: {e}")
            traceback.print_exc()
            return None

class YOLOLimeExplainer:
    """LIME image explainer that scores images with per-class YOLO detection confidences.

    NOTE(review): a second ``YOLOLimeExplainer`` is defined further down in this
    same cell; that later definition replaces this one when the cell is run.
    """

    def __init__(self, yolo_model, device='cuda', img_size=(640, 640)):
        """Store the detector and create the LIME image explainer.

        Args:
            yolo_model: Detector offering ``preprocessing(array)`` and a call
                that returns ``([boxes, classes, class_names, confidences], logits)``.
            device: Stored on the instance; not used by the methods of this class.
            img_size: Stored on the instance; not used by the methods of this class.
        """
        self.model = yolo_model
        self.device = device
        self.img_size = img_size
        self.explainer = lime_image.LimeImageExplainer()

    def predict_fn(self, images):
        """Classifier function for LIME.

        Each image is run through the detector on its own; the score of a class
        is the highest confidence among the detections of that class, 0 if none.

        Args:
            images: Iterable of image arrays.

        Returns:
            Array of shape ``(len(images), len(CORNEA_CLASSES))``. If anything
            fails the error is printed and an all-zero array is returned.
        """
        try:
            batch_predictions = []
            for img in images:
                processed_img = self.model.preprocessing(np.expand_dims(img, 0))
                if processed_img is None:
                    raise ValueError("Failed to preprocess image")

                with torch.no_grad():
                    predictions, _ = self.model(processed_img)

                class_scores = np.zeros(len(CORNEA_CLASSES))

                # predictions = [boxes, classes, class_names, confidences], one list per image
                if predictions[0][0]:
                    for cls_idx, conf in zip(predictions[1][0], predictions[3][0]):
                        if cls_idx < len(CORNEA_CLASSES):
                            class_scores[cls_idx] = max(class_scores[cls_idx], conf)

                batch_predictions.append(class_scores)

            return np.array(batch_predictions)

        except Exception as e:
            print(f"Error in prediction_fn: {e}")
            traceback.print_exc()
            return np.zeros((len(images), len(CORNEA_CLASSES)))

    def explain_instance(self, image, num_samples=1000, top_labels=5):
        """Generate a LIME explanation for one image.

        Args:
            image: PIL image or array; grey and RGBA arrays are converted to RGB.
            num_samples: Number of perturbed samples LIME evaluates.
            top_labels: Number of highest-scoring labels LIME explains.

        Returns:
            The LIME explanation with ``predict_fn`` attached, or ``None`` (after
            printing the error) when anything fails.
        """
        try:
            if isinstance(image, Image.Image):
                image = np.array(image)
            if len(image.shape) == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

            explanation = self.explainer.explain_instance(
                image,
                self.predict_fn,
                labels=range(len(CORNEA_CLASSES)),
                top_labels=top_labels,
                hide_color=0,
                num_samples=num_samples
            )

            # visualize_results() calls explanation.predict_fn, so attach it here.
            explanation.predict_fn = self.predict_fn

            return explanation

        except Exception as e:
            print(f"Error in explain_instance: {e}")
            traceback.print_exc()
            return None

def visualize_results(explanation, image, class_names=CORNEA_CLASSES, save_path=None):
    """
    Overlay the LIME weights of the two highest-scoring classes on the image.

    The classes are ranked by ``explanation.predict_fn`` evaluated on ``image``.
    One subplot is drawn per class, then the two scores are printed.

    Args:
        explanation: LIME explanation object; must expose ``predict_fn``,
            ``local_exp`` and ``segments``
        image: Original image array
        class_names: List of class names
        save_path: Optional path; when given the figure is written there
            before it is shown

    Any exception is printed together with its traceback instead of being raised.
    """
    try:
        # Get predictions for the unperturbed image
        prediction = explanation.predict_fn(np.array([image]))[0]

        # Keep the two labels with the highest score
        sorted_labels = sorted(range(len(prediction)),
                               key=lambda x: prediction[x],
                               reverse=True)[:2]

        def display_name(label):
            # A label outside class_names is shown with a placeholder name
            return class_names[label] if label < len(class_names) else f"Unknown Class {label}"

        plt.figure(figsize=(15, 7))

        for idx, label in enumerate(sorted_labels):
            plt.subplot(1, len(sorted_labels), idx + 1)

            # LIME stores weights only for the labels it explained. When the
            # ranking here differs from LIME's (e.g. tied scores), the label may
            # be missing; draw the image without an overlay in that case.
            segment_weights = explanation.local_exp.get(label, [])

            # Expand the (segment id, weight) pairs to a per-pixel weight map
            dense_mask = np.zeros(explanation.segments.shape, dtype=float)
            for segment_id, weight in segment_weights:
                dense_mask[explanation.segments == segment_id] = weight

            # Min-max normalise to [0, 1]; the lowest weight becomes fully transparent
            if dense_mask.max() != dense_mask.min():
                dense_mask = (dense_mask - dense_mask.min()) / (dense_mask.max() - dense_mask.min())

            # RGBA overlay: blue, with opacity following the normalised weight
            heatmap = np.zeros((dense_mask.shape[0], dense_mask.shape[1], 4))
            heatmap[:, :, 2] = dense_mask  # Blue channel
            heatmap[:, :, 3] = dense_mask * 1.0  # Alpha channel

            # Display original image
            plt.imshow(image, alpha=0.8)

            # Overlay heatmap
            plt.imshow(heatmap, alpha=0.6)

            plt.title(f'{display_name(label)}\nConfidence: {prediction[label]:.3f}', fontsize=12)
            plt.axis('off')

        plt.tight_layout()

        # Save before show(): the figure may be released once it has been shown
        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"Visualization saved to {save_path}")

        plt.show()

        print("\nDetailed Confidence Scores:")
        for label in sorted_labels:
            print(f"{display_name(label)}: {prediction[label]:.3f}")

    except Exception as e:
        print(f"Error in visualization: {e}")
        traceback.print_exc()

class YOLOLimeExplainer:
    """LIME image explainer that scores images with the YOLO detector's class logits."""

    def __init__(self, yolo_model, device='cuda', img_size=(640, 640)):
        """Store the detector and create the LIME image explainer.

        Args:
            yolo_model: Detector offering ``preprocessing(array)`` and a call
                that returns ``([boxes, classes, class_names, confidences], logits)``.
            device: Stored on the instance; not used by the methods of this class.
            img_size: Stored on the instance; not used by the methods of this class.
        """
        self.model = yolo_model
        self.device = device
        self.img_size = img_size
        self.explainer = lime_image.LimeImageExplainer()

    @staticmethod
    def _to_rgb(img):
        """Return a three-channel array: replicate a grey channel or drop the alpha channel.

        Done with NumPy so it works for any dtype, including the float64 arrays
        produced by the [0, 1] normalisation.
        """
        if img.ndim == 2:
            return np.stack((img,) * 3, axis=-1)
        if img.shape[2] == 1:
            return np.repeat(img, 3, axis=2)
        if img.shape[2] == 4:
            return img[..., :3]
        return img

    def predict_fn(self, images):
        """
        Classifier function for LIME.

        Each image is scaled to [0, 1] if needed, converted to RGB, turned into
        uint8 and run through the detector on its own. With detections and
        logits available, the score vector is the softmax of the logits averaged
        over the detections; without logits it is the highest confidence per
        detected class. Images without detections score all zeros.

        Returns:
            Array with one row of ``len(CORNEA_CLASSES)`` scores per image. If
            anything fails the error is printed and an all-zero array is returned.
        """
        try:
            batch_predictions = []
            for img in images:
                img = np.asarray(img)

                # Normalize image if needed
                if img.max() > 1.0:
                    img = img / 255.0

                # Convert image format if needed
                img = self._to_rgb(img)

                # Round (not truncate) so values that came from uint8 survive
                # the float round trip unchanged
                img_uint8 = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
                processed_img = self.model.preprocessing(np.expand_dims(img_uint8, 0))

                if processed_img is None:
                    raise ValueError("Failed to preprocess image")

                with torch.no_grad():
                    predictions, logits = self.model(processed_img)

                class_scores = np.zeros(len(CORNEA_CLASSES))

                if predictions[0][0]:  # If there are any detections
                    if logits is not None and len(logits[0]) > 0:
                        # Use logits if available
                        probs = torch.nn.functional.softmax(logits[0], dim=1)
                        class_scores = probs.mean(dim=0).cpu().numpy()
                    else:
                        # Fallback to confidence scores
                        for cls_idx, conf in zip(predictions[1][0], predictions[3][0]):
                            if cls_idx < len(CORNEA_CLASSES):
                                class_scores[cls_idx] = max(class_scores[cls_idx], conf)

                batch_predictions.append(class_scores)

            return np.array(batch_predictions)

        except Exception as e:
            print(f"Error in prediction_fn: {e}")
            traceback.print_exc()
            return np.zeros((len(images), len(CORNEA_CLASSES)))

    def explain_instance(self, image, num_samples=1000, top_labels=5):
        """
        Generate a LIME explanation for one image.

        Args:
            image: PIL image or array; values above 1 are scaled by 1/255 and
                grey / RGBA input is converted to RGB.
            num_samples: Number of perturbed samples LIME evaluates.
            top_labels: Number of highest-scoring labels LIME explains.

        Returns:
            The LIME explanation with ``predict_fn`` attached, or ``None`` (after
            printing the error) when anything fails.
        """
        try:
            # Convert PIL Image to numpy array if needed
            if isinstance(image, Image.Image):
                image = np.array(image)

            # Normalize image if needed
            if image.max() > 1.0:
                image = image.astype(float) / 255.0

            # Convert image format if needed (dtype-independent, see _to_rgb)
            image = self._to_rgb(image)

            # Generate explanation
            explanation = self.explainer.explain_instance(
                image,
                self.predict_fn,
                labels=range(len(CORNEA_CLASSES)),
                top_labels=top_labels,
                hide_color=0,
                num_samples=num_samples
            )

            # Store predict_fn for later use
            explanation.predict_fn = self.predict_fn

            return explanation

        except Exception as e:
            print(f"Error in explain_instance: {e}")
            traceback.print_exc()
            return None

def run_lime_analysis(model_path, img_path, num_samples=500, save_path=None):
    """
    Load the detector, explain one image with LIME and visualize the result.

    Args:
        model_path: Path to the YOLO model weights
        img_path: Path to the input image
        num_samples: Number of samples for LIME analysis
        save_path: Optional path, forwarded to visualize_results

    Returns:
        The LIME explanation, or None when the explanation could not be
        generated or any step raised (the error is printed with its traceback).
    """
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        print("Loading YOLO model...")
        yolo_model = YOLOV5TorchObjectDetector(
            model_weight=model_path,
            device=device,
            img_size=(640, 640),
            names=CORNEA_CLASSES
        )

        print("Initializing LIME explainer...")
        # Pass the device actually selected instead of relying on the 'cuda' default
        lime_explainer = YOLOLimeExplainer(yolo_model, device=device)

        print("Loading and processing image...")
        # The context manager closes the file handle; convert('RGB') gives a
        # three-channel array for grey, palette and RGBA files alike.
        with Image.open(img_path) as image:
            image_array = np.array(image.convert('RGB'))

        # Normalize image if needed
        if image_array.max() > 1.0:
            image_array = image_array.astype(float) / 255.0

        print(f"Running LIME analysis with {num_samples} samples...")
        explanation = lime_explainer.explain_instance(
            image_array,
            num_samples=num_samples,
            top_labels=3
        )

        if explanation is None:
            print("Failed to generate explanation")
            return None

        print("Visualizing results...")
        visualize_results(explanation, image_array, save_path=save_path)

        return explanation

    except Exception as e:
        print(f"Error in run_lime_analysis: {e}")
        traceback.print_exc()
        return None

# Example run: explain a single slit-lamp photo with 500 LIME samples.
if __name__ == "__main__":
    model_path = "/gdrive/MyDrive/Deep_learning/CorneAI_nagoya/yolo5_forcresco/weights/eye_nii_2202_onecaseoneimage2_doctorcompare_yolov5s_epoch200_batch16_89.8p/last.pt"
    img_path = "/gdrive/MyDrive/研究/進行中の研究/角膜スマートフォンAIプロジェクト/前原の240問/フォトスリット_serial/1.jpg"
    # Optional output file for the figure
    save_path = "lime_explanation.png"

    explanation = run_lime_analysis(model_path, img_path, num_samples=500, save_path=save_path)

# **RISE**

In [163]:
class YOLOv5TorchObjectDetector(nn.Module):
    """YOLOv5 wrapper that returns post-NMS detections together with their class logits.

    The weights are loaded with ``attempt_load``; ``forward`` unpacks the model
    output as ``(prediction, logits, _)``, so the loaded model must return that
    triple.
    """

    def __init__(self,
                 model_weight,
                 device,
                 img_size,
                 names=None,
                 mode='eval',
                 confidence=0.25,
                 iou_thresh=0.45,
                 agnostic_nms=False):
        """Load the model, store the thresholds and run one warm-up inference.

        Args:
            model_weight: Weights path handed to ``attempt_load``.
            device: Device the model and the input tensors are placed on.
            img_size: ``(height, width)`` or a single int for a square input.
            names: Class names indexed by class id; ``None`` selects the COCO names.
            mode: ``'train'`` puts the model in training mode, anything else in
                eval mode.
            confidence: Confidence threshold passed to NMS.
            iou_thresh: IoU threshold passed to NMS.
            agnostic_nms: Passed to NMS as ``agnostic``.
        """
        super(YOLOv5TorchObjectDetector, self).__init__()
        self.device = device
        self.model = None
        self.img_size = img_size
        self.mode = mode
        self.confidence = confidence
        self.iou_thresh = iou_thresh
        self.agnostic = agnostic_nms
        self.model = attempt_load(model_weight, device=device)
        print("[INFO] Model is loaded")
        # Parameters keep requires_grad=True even though inference is the default mode.
        self.model.requires_grad_(True)
        self.model.to(device)
        if self.mode == 'train':
            self.model.train()
        else:
            self.model.eval()
        # fetch the names
        if names is None:
            print('[INFO] fetching names from coco file')
            self.names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                          'traffic light',
                          'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
                          'cow',
                          'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
                          'frisbee',
                          'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
                          'surfboard',
                          'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
                          'apple',
                          'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                          'couch',
                          'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                          'keyboard', 'cell phone',
                          'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                          'teddy bear',
                          'hair drier', 'toothbrush']
        else:
            self.names = names

        # preventing cold start: one dummy forward pass at the configured size
        img = torch.zeros((1, 3, *self._as_hw(self.img_size)), device=device)
        self.model(img)

    @staticmethod
    def _as_hw(img_size):
        """Return ``img_size`` as a ``(height, width)`` tuple; a bare int means a square input."""
        if isinstance(img_size, int):
            return img_size, img_size
        return tuple(img_size)

    @staticmethod
    def non_max_suppression(prediction, logits, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False,
                            multi_label=False, labels=(), max_det=300):
        """Runs Non-Maximum Suppression (NMS) on inference and logits results

        Args:
            prediction: Tensor indexed as ``[image, candidate, 5 + nc]`` with
                xywh in columns 0-3, objectness in column 4 and class scores after.
            logits: Per-candidate logits, aligned row by row with ``prediction``.
            conf_thres: Minimum objectness and minimum final confidence.
            iou_thres: IoU threshold for ``torchvision.ops.nms``.
            classes: Optional class ids to keep.
            agnostic: When true, boxes of different classes suppress each other.
            multi_label: Allow several labels per box (only when nc > 1).
            labels: Optional per-image apriori labels (autolabelling).
            max_det: Maximum number of detections kept per image.

        Returns:
             list of detections, on (n,6) tensor per image [xyxy, conf, cls] and pruned input logits (n, number-classes)
        """

        nc = prediction.shape[2] - 5  # number of classes
        xc = prediction[..., 4] > conf_thres  # candidates

        # Checks
        assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
        assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

        # Settings
        max_wh = 4096  # (pixels) class offset used to keep classes apart in batched NMS
        max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
        time_limit = 10.0  # seconds to quit after
        redundant = True  # require redundant detections
        multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
        merge = False  # use merge-NMS

        t = time.time()
        output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
        # Size the empty placeholder by the model's class count, not by COCO's 80
        logits_output = [torch.zeros((0, nc), device=logits.device)] * logits.shape[0]
        for xi, (x, log_) in enumerate(zip(prediction, logits)):  # image index, image inference
            # Apply constraints
            x = x[xc[xi]]  # confidence
            log_ = log_[xc[xi]]
            # Cat apriori labels if autolabelling
            if labels and len(labels[xi]):
                l = labels[xi]
                v = torch.zeros((len(l), nc + 5), device=x.device)
                v[:, :4] = l[:, 1:5]  # box
                v[:, 4] = 1.0  # conf
                v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
                x = torch.cat((x, v), 0)
                # Keep one logits row per box; the injected labels get their one-hot
                # class vector as a placeholder (assumes logits have nc columns)
                log_ = torch.cat((log_, v[:, 5:].to(device=log_.device, dtype=log_.dtype)), 0)

            # If none remain process next image
            if not x.shape[0]:
                continue

            # Compute conf
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
            # Box (center x, center y, width, height) to (x1, y1, x2, y2)
            box = xywh2xyxy(x[:, :4])

            # Detections matrix nx6 (xyxy, conf, cls)
            if multi_label:
                i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
                x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
                log_ = log_[i]  # one logits row per emitted (box, label) pair
            else:  # best class only
                conf, j = x[:, 5:].max(1, keepdim=True)
                keep = conf.view(-1) > conf_thres
                x = torch.cat((box, conf, j.float()), 1)[keep]
                log_ = log_[keep]
            # Filter by class (logits filtered with the same mask)
            if classes is not None:
                keep = (x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)
                x, log_ = x[keep], log_[keep]

            # Check shape
            n = x.shape[0]  # number of boxes
            if not n:  # no boxes
                continue
            elif n > max_nms:  # excess boxes
                order = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence
                x, log_ = x[order], log_[order]

            # Batched NMS
            c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
            boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
            if i.shape[0] > max_det:  # limit detections
                i = i[:max_det]
            if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
                # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy

            output[xi] = x[i]
            logits_output[xi] = log_[i]
            assert log_[i].shape[0] == x[i].shape[0]
            if (time.time() - t) > time_limit:
                print(f'WARNING: NMS time limit {time_limit}s exceeded')
                break  # time limit exceeded

        return output, logits_output

    @staticmethod
    def yolo_resize(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
        """Thin wrapper around ``letterbox``; returns whatever ``letterbox`` returns."""
        return letterbox(img, new_shape=new_shape, color=color, auto=auto, scaleFill=scaleFill, scaleup=scaleup)

    def forward(self, img):
        """Detect objects in a preprocessed batch.

        Args:
            img: Batch tensor as produced by ``preprocessing``.

        Returns:
            ``([boxes, classes, class_names, confidences], logits)`` where each
            of the four lists holds one list per image.
        """
        prediction, logits, _ = self.model(img, augment=False)
        prediction, logits = self.non_max_suppression(prediction, logits, self.confidence, self.iou_thresh,
                                                      classes=None,
                                                      agnostic=self.agnostic)
        self.boxes, self.class_names, self.classes, self.confidences = [[[] for _ in range(img.shape[0])] for _ in
                                                                        range(4)]
        max_h, max_w = self._as_hw(self.img_size)
        for i, det in enumerate(prediction):  # detections per image
            if len(det):
                for *xyxy, conf, cls in det:
                    # Clamp the box to the configured image size
                    xyxy[0] = max(0, xyxy[0])
                    xyxy[1] = max(0, xyxy[1])
                    xyxy[2] = min(max_w, xyxy[2])
                    xyxy[3] = min(max_h, xyxy[3])

                    bbox = Box.box2box(xyxy,
                                       in_source=Box.BoxSource.Torch,
                                       to_source=Box.BoxSource.Numpy,
                                       return_int=True)
                    self.boxes[i].append(bbox)
                    self.confidences[i].append(round(conf.item(), 2))
                    cls = int(cls.item())
                    self.classes[i].append(cls)
                    if self.names is not None:
                        self.class_names[i].append(self.names[cls])
                    else:
                        self.class_names[i].append(cls)
        return [self.boxes, self.classes, self.class_names, self.confidences], logits

    def preprocessing(self, img):
        """Turn image array(s) into a normalised batch tensor on ``self.device``.

        Args:
            img: Array indexed ``[H, W, C]`` or ``[N, H, W, C]``; it is cast to uint8.

        Returns:
            The letterboxed batch as a tensor in channel-first layout scaled by 1/255.
        """
        if len(img.shape) != 4:
            img = np.expand_dims(img, axis=0)
        im0 = img.astype(np.uint8)
        img = np.array([self.yolo_resize(im, new_shape=self.img_size)[0] for im in im0])
        img = img.transpose((0, 3, 1, 2))
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img / 255.0
        return img

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.transform import resize
from tqdm import tqdm

import os
import time
import numpy as np
import cv2
from google.colab.patches import cv2_imshow
from deep_utils import Box, split_extension
import gc

import torch
import torch.nn.functional as F
import torchvision
import torch.nn as nn
from utils.general import xywh2xyxy
from utils.datasets import letterbox
from utils.metrics import box_iou
from models.experimental import attempt_load

# Class-name list for the detector (same nine names, in the same order, as
# CORNEA_CLASSES in the LIME section).
names = ["infection", "normal", "non-infection", "scar", "tumor", "deposit", "APAC", "lens opacity", "bullous"]

class RISEForYOLOv5(nn.Module):
    """RISE-style saliency maps for the top-1 detection of a YOLOv5 model.

    The input image is multiplied by a bank of random masks, the detector is
    re-run on every masked copy, and each mask is weighted by the confidence
    of the detection that still matches the explained box (same class,
    IoU > 0.5). The weighted sum of the masks is the saliency map.

    Attributes set by ``generate_masks`` / ``load_masks``:
        masks: ``(N, 1, H, W)`` float tensor on ``device``.
        N: number of masks.
        p1: probability with which a grid cell is kept (set to 1).
        input_size: ``(H, W)`` of the masks / last processed image.
    """

    def __init__(self, model, detector, device, gpu_batch=20):
        """
        Args:
            model: callable returning ``(prediction, logits, _)`` for a batch
                (called as ``model(batch, augment=False)``).
            detector: object providing ``non_max_suppression`` plus the
                ``confidence``, ``iou_thresh`` and ``agnostic`` settings.
            device: torch device (or device string) used for masks and scores.
            gpu_batch (int): number of masked images evaluated per forward pass.
        """
        super(RISEForYOLOv5, self).__init__()
        self.model = model
        self.detector = detector
        self.device = device
        self.gpu_batch = gpu_batch
        self.masks = None
        self.N = None
        self.p1 = None
        self.input_size = None  # filled in once an image or mask bank is seen

    def forward(self, img, boxes, scores, classes):
        """Compute one saliency map per image for its highest-scoring detection.

        Args:
            img (torch.Tensor): ``(B, C, H, W)`` input batch.
            boxes, scores, classes: per-image lists of detections, as returned
                by the detector (index ``b`` selects the image).

        Returns:
            list: one ``(H, W)`` float array in ``[0, 1]`` per image, or
            ``None`` where the image has no detection or processing failed.
        """
        B, C, H, W = img.size()
        self.input_size = (H, W)

        # Letterbox every image to the current input size and remember the
        # scale / padding so the boxes can be mapped with the same transform.
        resized_imgs = []
        pad_info = []
        for b in range(B):
            img_np = img[b].cpu().numpy().transpose(1, 2, 0)
            img_np, ratio, (dw, dh) = letterbox(img_np, new_shape=self.input_size, auto=False, scaleFill=False)
            resized_imgs.append(img_np)
            pad_info.append((ratio, dw, dh))

        # Back to an NCHW float tensor; everything below works on img_resized.
        img_resized = np.stack(resized_imgs, axis=0)
        img_resized = img_resized.transpose(0, 3, 1, 2)
        img_resized = torch.from_numpy(img_resized).to(self.device).float()

        saliency_maps = []

        for b in range(B):
            if len(boxes[b]) == 0 or len(scores[b]) == 0:
                print(f"[INFO] No detections for batch {b}")
                saliency_maps.append(None)
                continue

            try:
                scores_tensor = torch.tensor(scores[b])
                if scores_tensor.numel() == 0:
                    print(f"[INFO] Empty scores tensor for batch {b}")
                    saliency_maps.append(None)
                    continue

                # Explain only the most confident detection of this image.
                top1_idx = int(torch.argmax(scores_tensor))
                cls = classes[b][top1_idx]
                score = scores[b][top1_idx]
                bbox = boxes[b][top1_idx]

                # Apply the same resize / padding to the box coordinates.
                ratio, dw, dh = pad_info[b]
                bbox_resized = self.adjust_bbox(bbox, ratio, dw, dh)

                saliency_map = self.apply_rise(img_resized[b], cls, score, bbox_resized)

                saliency_maps.append(saliency_map)
            except Exception as e:
                # Best effort: a failure on one image must not abort the batch.
                print(f"[ERROR] Error processing batch {b}: {str(e)}")
                saliency_maps.append(None)

        return saliency_maps

    def adjust_bbox(self, bbox, ratio, dw, dh):
        """Map a box into letterboxed coordinates (scale by ``ratio``, shift by padding).

        Returns:
            list[int]: ``[x1, y1, x2, y2]`` truncated to integers.
        """
        x1, y1, x2, y2 = bbox
        x1 = x1 * ratio[0] + dw
        x2 = x2 * ratio[0] + dw
        y1 = y1 * ratio[1] + dh
        y2 = y2 * ratio[1] + dh
        return [int(x1), int(y1), int(x2), int(y2)]


    def apply_rise(self, img, cls, score, bbox):
        """Run RISE for one image and one target detection.

        Args:
            img (torch.Tensor): ``(C, H, W)`` image on ``self.device``.
            cls (int): class id of the explained detection.
            score (float): confidence of the explained detection (not used in
                the computation; kept for interface compatibility).
            bbox (list): ``[x1, y1, x2, y2]`` of the explained detection.

        Returns:
            np.ndarray: ``(H, W)`` saliency map min-max normalised to ``[0, 1]``
            (all zeros when no masked image reproduced the detection).

        Raises:
            ValueError: if no mask bank is available and none can be generated
                because the number of masks is unknown.
        """
        _, H, W = img.size()
        self.input_size = (H, W)

        if self.masks is None or tuple(self.masks.shape[2:]) != (H, W):
            if self.N is None:
                raise ValueError(
                    "RISE masks are not initialised: call generate_masks() or load_masks() first.")
            # Rebuild the bank for the new size, keeping the configured keep-probability.
            print(f"[INFO] Regenerating masks for size ({H}, {W})")
            self.generate_masks(N=self.N, s=8,
                                p1=self.p1 if self.p1 is not None else 0.1,
                                img_size=(H, W))

        N = self.N
        scores = torch.zeros(N, device=self.device)
        bbox_tensor = torch.tensor([bbox], device=self.device).float()
        img_batch = img.unsqueeze(0)  # 1 x C x H x W, broadcast against the masks

        for i in range(0, N, self.gpu_batch):
            with torch.no_grad():
                # Mask only the current chunk instead of materialising all
                # N x C x H x W masked images at once.
                batch_imgs = self.masks[i:i + self.gpu_batch] * img_batch
                prediction, logits, _ = self.model(batch_imgs, augment=False)

                outputs, logits_output = self.detector.non_max_suppression(
                    prediction,
                    logits,
                    conf_thres=self.detector.confidence,
                    iou_thres=self.detector.iou_thresh,
                    classes=None,
                    agnostic=self.detector.agnostic
                )

                for j, detections in enumerate(outputs):
                    idx = i + j
                    if detections is not None and len(detections):
                        # Keep detections of the explained class only.
                        detections = detections[detections[:, 5] == cls]
                        if len(detections):
                            ious = box_iou(detections[:, :4], bbox_tensor)
                            max_iou, max_idx = ious.max(0)
                            # The mask scores only if the same object is still found.
                            if max_iou > 0.5:
                                scores[idx] = detections[max_idx, 4]

        saliency = torch.matmul(scores.unsqueeze(0), self.masks.view(N, -1))  # 1 x (H*W)
        saliency = saliency.view(H, W)
        if self.p1:
            # Monte-Carlo normaliser of RISE. It is a constant factor, so the
            # min-max normalisation below gives the same map with or without it.
            saliency = saliency / (N * self.p1)
        saliency_map = saliency.cpu().numpy()

        if saliency_map.max() > saliency_map.min():
            saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())
        else:
            saliency_map = np.zeros_like(saliency_map)

        return saliency_map


    def generate_masks(self, N, s, p1, img_size=None, savepath=None):
        """Create ``N`` random masks and store them on ``self.device``.

        A binary ``s x s`` grid (cell kept with probability ``p1``) is bilinearly
        up-sampled straight to the image size; no random shift is applied.

        Args:
            N (int): number of masks.
            s (int): grid resolution.
            p1 (float): probability of a grid cell being 1.
            img_size (tuple, optional): ``(H, W)``; defaults to ``self.input_size``.
            savepath (str, optional): if given, the masks are saved there as ``.npy``.

        Raises:
            ValueError: if neither ``img_size`` nor ``self.input_size`` is set.
        """
        if img_size is None:
            if self.input_size is None:
                raise ValueError("img_size must be provided if self.input_size is not set.")
            img_size = self.input_size
        else:
            self.input_size = img_size

        grid = (np.random.rand(N, s, s) < p1).astype('float32')

        # float32 buffer: cv2 already produces float32 and the tensor is float32 too.
        masks = np.empty((N, *img_size), dtype=np.float32)
        for i in tqdm(range(N), desc='Generating masks'):
            # cv2.resize takes the target size as (width, height).
            masks[i] = cv2.resize(grid[i],
                                  (img_size[1], img_size[0]),
                                  interpolation=cv2.INTER_LINEAR)

        masks = masks.reshape(-1, 1, *img_size)
        if savepath is not None:
            np.save(savepath, masks)
        self.masks = torch.from_numpy(masks).float().to(self.device)
        self.N = N
        self.p1 = p1

    def load_masks(self, filepath, p1=None):
        """Load a mask bank saved by ``generate_masks``.

        Args:
            filepath (str): path of the ``.npy`` file.
            p1 (float, optional): keep-probability the masks were generated
                with. When omitted it is estimated as the mean mask value, so
                that ``apply_rise`` always has a usable ``self.p1``.
        """
        masks = np.load(filepath)
        self.masks = torch.from_numpy(masks).float().to(self.device)
        self.N = self.masks.shape[0]
        self.input_size = tuple(self.masks.shape[2:])  # (H, W) taken from the masks
        self.p1 = float(p1) if p1 is not None else float(self.masks.mean())


    def crop_image(self, img, bbox):
        """Validate ``bbox`` against the input size and return the letterboxed frame.

        Despite its name the method keeps the whole image: the box is only
        clamped and returned alongside the frame. The frame goes through
        ``letterbox`` and has its channel order reversed.

        Returns:
            tuple: ``(tensor, (x1, y1, x2, y2))``; the tensor is empty when the
            clamped box has no area.
        """
        x1, y1, x2, y2 = map(int, bbox)
        # Clamp the box to the image.
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(self.input_size[1], x2)  # width
        y2 = min(self.input_size[0], y2)  # height

        if x2 <= x1 or y2 <= y1:
            return torch.tensor([]).to(self.device), (x1, y1, x2, y2)

        cropped = img[:, :, :]  # full frame, no spatial cropping
        cropped_np = cropped.cpu().numpy().transpose(1, 2, 0)
        cropped_np = (cropped_np * 255.0).astype(np.uint8)

        cropped_np = letterbox(cropped_np, new_shape=self.input_size, stride=32, auto=False)[0]

        # Reverse the channel axis and go back to CHW.
        cropped_np = cropped_np[:, :, ::-1].transpose(2, 0, 1)
        cropped_np = np.ascontiguousarray(cropped_np)
        cropped_tensor = torch.from_numpy(cropped_np).float().to(self.device) / 255.0
        return cropped_tensor, (x1, y1, x2, y2)



    def expand_saliency_map(self, saliency_map, bbox_coords, H, W):
        """Paste a saliency map into the box region of an ``(H, W)`` canvas.

        The map is resized to the (clamped) box size, written into a zero
        canvas and the canvas is min-max normalised.

        Returns:
            np.ndarray: ``(H, W)`` float32 map; all zeros for an empty box or
            when resizing / pasting fails.
        """
        x1, y1, x2, y2 = bbox_coords

        # Clamp the box to the canvas.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(W, x2), min(H, y2)

        box_height = y2 - y1
        box_width = x2 - x1

        if box_height <= 0 or box_width <= 0:
            print(f"Invalid box dimensions: width={box_width}, height={box_height}")
            return np.zeros((H, W), dtype=np.float32)

        try:
            # cv2.resize takes the target size as (width, height).
            sal_map_resized = cv2.resize(saliency_map, (box_width, box_height),
                                      interpolation=cv2.INTER_LINEAR)
        except Exception as e:
            print(f"Resize error: {e}")
            print(f"Source shape: {saliency_map.shape}")
            print(f"Target shape: ({box_height}, {box_width})")
            return np.zeros((H, W), dtype=np.float32)

        sal_map_full = np.zeros((H, W), dtype=np.float32)

        try:
            sal_map_full[y1:y2, x1:x2] = sal_map_resized
        except ValueError as e:
            print(f"Assignment error: {e}")
            print(f"Resized shape: {sal_map_resized.shape}")
            print(f"Target region shape: ({y2-y1}, {x2-x1})")
            return np.zeros((H, W), dtype=np.float32)

        # Normalise to [0, 1] unless the canvas is constant.
        if sal_map_full.max() > sal_map_full.min():
            sal_map_full = (sal_map_full - sal_map_full.min()) / (sal_map_full.max() - sal_map_full.min())

        return sal_map_full


class YOLOV5TorchObjectDetector(nn.Module):
    """YOLOv5 wrapper that returns detections together with their class logits."""

    def __init__(self,
                 model_weight,
                 device,
                 img_size,
                 rise_explainer=None,
                 names=None,
                 mode='eval',
                 confidence=0.25,
                 iou_thresh=0.45,
                 agnostic_nms=False):
        """
        Initialise the YOLOv5 object detector.

        Args:
            model_weight (str): path of the YOLOv5 weight file.
            device (str): device to run on ('cuda' or 'cpu').
            img_size (int or tuple): input image size; an int is expanded to a square.
            rise_explainer (RISEForYOLOv5, optional): RISE instance, stored as-is.
            names (list, optional): class names; the COCO names are used when omitted.
            mode (str, optional): accepted for compatibility; the model is always
                put in eval mode regardless of this value.
            confidence (float, optional): confidence threshold.
            iou_thresh (float, optional): IoU threshold for NMS.
            agnostic_nms (bool, optional): whether NMS ignores the class.
        """
        super(YOLOV5TorchObjectDetector, self).__init__()
        self.device = device
        # No map_location is passed here; the model is moved to `device` right after.
        self.model = attempt_load(model_weight)
        print("[INFO] YOLOv5 model loaded")
        self.model.to(device)
        self.model.eval()  # inference mode
        self.rise_explainer = rise_explainer

        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.confidence = confidence
        self.iou_thresh = iou_thresh
        self.agnostic = agnostic_nms

        # Class names (COCO fallback).
        if names is None:
            self.names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                          'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
                          'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
                          'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                          'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                          'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                          'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                          'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
                          'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                          'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
                          'toothbrush']
        else:
            self.names = names

        # Warm-up pass so the first real inference does not pay the start-up cost.
        img = torch.zeros((1, 3, *self.img_size), device=device)
        with torch.no_grad():
            self.model(img)

    @staticmethod
    def non_max_suppression(prediction, logits, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False,
                            multi_label=False, labels=(), max_det=300):
        """Non-maximum suppression that keeps the class logits aligned with the boxes.

        Args:
            prediction (torch.Tensor): ``(B, n, 5 + nc)`` raw predictions
                (xywh, objectness, class scores).
            logits (torch.Tensor): ``(B, n, k)`` logits, one row per prediction row.
            conf_thres (float): confidence threshold in ``[0, 1]``.
            iou_thres (float): IoU threshold in ``[0, 1]``.
            classes (list, optional): keep only these class ids.
            agnostic (bool): suppress across classes when True.
            multi_label (bool): allow several labels per box.
            labels: optional a-priori labels appended to the candidates.
            max_det (int): maximum number of detections per image.

        Returns:
            tuple: ``(output, logits_output)``; per image a ``(m, 6)`` tensor
            ``[x1, y1, x2, y2, conf, cls]`` and the matching ``(m, k)`` logits.
            Images without detections get empty ``(0, 6)`` / ``(0, k)`` tensors.
        """
        nc = prediction.shape[2] - 5  # number of classes
        xc = prediction[..., 4] > conf_thres  # candidates

        # Checks
        assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
        assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

        # Settings
        min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
        max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
        time_limit = 10.0  # seconds to quit after
        redundant = True  # require redundant detections
        multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
        merge = False  # use merge-NMS

        t = time.time()
        output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
        # The empty placeholder takes its width from the logits actually passed in,
        # so it matches models with any number of classes.
        logits_output = [torch.zeros((0, logits.shape[-1]), device=logits.device)] * logits.shape[0]

        for xi, (x, log_) in enumerate(zip(prediction, logits)):  # image index, image inference
            x = x[xc[xi]]  # confidence
            log_ = log_[xc[xi]]
            if labels and len(labels[xi]):
                # NOTE(review): rows are appended to x only; log_ is not extended
                # for them, so the logits no longer line up when labels are
                # supplied - confirm whether this path is used.
                l = labels[xi]
                v = torch.zeros((len(l), nc + 5), device=x.device)
                v[:, :4] = l[:, 1:5]  # box
                v[:, 4] = 1.0  # conf
                v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
                x = torch.cat((x, v), 0)

            if not x.shape[0]:
                continue

            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
            box = YOLOV5TorchObjectDetector.xywh2xyxy_custom(x[:, :4])

            if multi_label:
                i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
                x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
                log_ = log_[i]  # one logits row per (box, label) pair
            else:  # best class only
                conf, j = x[:, 5:].max(1, keepdim=True)
                keep = conf.view(-1) > conf_thres
                x = torch.cat((box, conf, j.float()), 1)[keep]
                log_ = log_[keep]
            if classes is not None:
                # Filter boxes and logits with the same mask to keep them aligned.
                keep = (x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)
                x, log_ = x[keep], log_[keep]

            n = x.shape[0]  # number of boxes
            if not n:  # no boxes
                continue
            elif n > max_nms:  # too many boxes: keep the most confident ones
                order = x[:, 4].argsort(descending=True)[:max_nms]
                x, log_ = x[order], log_[order]

            c = x[:, 5:6] * (0 if agnostic else 4096)  # classes
            boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
            if i.shape[0] > max_det:  # limit detections
                i = i[:max_det]
            if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy

            output[xi] = x[i]
            logits_output[xi] = log_[i]
            assert log_[i].shape[0] == x[i].shape[0]
            if (time.time() - t) > time_limit:
                print(f'WARNING: NMS time limit {time_limit}s exceeded')
                break  # time limit exceeded

        return output, logits_output

    @staticmethod
    def xywh2xyxy_custom(x):
        """
        Convert [x, y, w, h] (centre, size) to [x1, y1, x2, y2] (corners).

        Args:
            x (torch.Tensor): ``(n, 4)`` boxes as [x, y, w, h].

        Returns:
            torch.Tensor: ``(n, 4)`` boxes as [x1, y1, x2, y2].
        """
        y = x.clone()
        y[:, 0] = x[:, 0] - x[:, 2] / 2  # x1
        y[:, 1] = x[:, 1] - x[:, 3] / 2  # y1
        y[:, 2] = x[:, 0] + x[:, 2] / 2  # x2
        y[:, 3] = x[:, 1] + x[:, 3] / 2  # y2
        return y

    @staticmethod
    def yolo_resize(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
        """Thin wrapper around ``letterbox``; returns whatever ``letterbox`` returns."""
        return letterbox(img, new_shape=new_shape, color=color, auto=auto, scaleFill=scaleFill, scaleup=scaleup)

    def preprocessing(self, img):
        """Turn one HWC image (or a batch) into a letterboxed NCHW tensor divided by 255."""
        if len(img.shape) != 4:
            img = np.expand_dims(img, axis=0)
        im0 = img.astype(np.uint8)
        img = np.array([self.yolo_resize(im, new_shape=self.img_size)[0] for im in im0])
        img = img.transpose((0, 3, 1, 2))
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img / 255.0
        return img

    def forward(self, img):
        """Detect objects in a pre-processed batch.

        Args:
            img (torch.Tensor): ``(B, C, H, W)`` batch.

        Returns:
            tuple: ``([boxes, classes, class_names, confidences], logits)``;
            each of the four lists holds one list per image, ``logits`` is the
            per-image list of logits kept by NMS.
        """
        with torch.no_grad():
            prediction, logits, _ = self.model(img, augment=False)
        prediction, logits = self.non_max_suppression(prediction, logits, self.confidence, self.iou_thresh,
                                                      classes=None,
                                                      agnostic=self.agnostic)

        batch_size = img.size(0)
        self.boxes = [[] for _ in range(batch_size)]
        self.class_names = [[] for _ in range(batch_size)]
        self.classes = [[] for _ in range(batch_size)]
        self.confidences = [[] for _ in range(batch_size)]

        for i, det in enumerate(prediction):  # detections per image
            if len(det):
                for *xyxy, conf, cls in det:
                    # Clamp the box to the network input size.
                    xyxy[0] = max(0, xyxy[0])
                    xyxy[1] = max(0, xyxy[1])
                    xyxy[2] = min(self.img_size[1], xyxy[2])
                    xyxy[3] = min(self.img_size[0], xyxy[3])

                    # NOTE(review): the box is converted from the Torch to the
                    # Numpy box source; confirm the resulting coordinate order
                    # matches the (x1, y1, x2, y2) unpacking done by callers.
                    bbox = Box.box2box(xyxy,
                                       in_source=Box.BoxSource.Torch,
                                       to_source=Box.BoxSource.Numpy,
                                       return_int=True)
                    self.boxes[i].append(bbox)
                    self.confidences[i].append(conf.item())
                    cls = int(cls.item())
                    self.classes[i].append(cls)
                    if self.names is not None:
                        self.class_names[i].append(self.names[cls])
                    else:
                        self.class_names[i].append(cls)
            else:
                print(f"No detections for batch {i}")

        return [self.boxes, self.classes, self.class_names, self.confidences], logits

# Helper functions
# get_res_img takes the letterbox parameters (ratio, dw, dh) so it can map a box back onto the original image.
def get_res_img(bbox, masks, res_img, ratio, dw, dh):
    """
    Blend saliency maps onto the original image inside the detection box.

    The box is first mapped from letterboxed coordinates back to the original
    image (padding removed, then divided by the resize ratio) and clamped to
    the image. Every mask is min-max normalised, resized to the box size,
    colour-mapped and alpha-blended into the box region.

    NOTE(review): each mask is resized to the box size as a whole; callers in
    this file pass a full-frame saliency map - confirm that squeezing it into
    the box is intended.

    Args:
        bbox (list): [x1, y1, x2, y2] - box in model (letterboxed) coordinates
        masks (list): saliency maps; ``None`` entries are skipped
        res_img (np.ndarray): original image in BGR format
        ratio (tuple): (width ratio, height ratio) used when resizing
        dw (float): horizontal padding
        dh (float): vertical padding

    Returns:
        tuple: the blended image and the last heat-map; a 1x1x3 zero image is
        returned as heat-map when nothing was drawn.
    """
    img_h, img_w = res_img.shape[:2]

    # Undo the letterbox transform, then clamp to the image and truncate to int.
    left, top, right, bottom = bbox
    left = int(max(0, (left - dw) / ratio[0]))
    top = int(max(0, (top - dh) / ratio[1]))
    right = int(min(img_w, (right - dw) / ratio[0]))
    bottom = int(min(img_h, (bottom - dh) / ratio[1]))

    box_h = bottom - top
    box_w = right - left

    if box_h <= 0 or box_w <= 0:
        print(f"Invalid box dimensions: width={box_w}, height={box_h}")
        return res_img, np.zeros((1, 1, 3), dtype=np.uint8)

    roi = (slice(top, bottom), slice(left, right))
    last_heatmap = None

    for raw_mask in masks:
        if raw_mask is None:
            continue

        sal = raw_mask.squeeze().astype(np.float32)
        lowest, highest = sal.min(), sal.max()
        if highest > lowest:
            sal = (sal - lowest) / (highest - lowest)

        try:
            # cv2.resize takes the target size as (width, height).
            boxed = cv2.resize(sal, (box_w, box_h), interpolation=cv2.INTER_LINEAR)

            canvas = np.zeros((img_h, img_w), dtype=np.float32)
            canvas[roi] = boxed

            heatmap = cv2.applyColorMap((canvas * 255).astype(np.uint8), cv2.COLORMAP_JET)

            blended = res_img.copy()
            # 0.7 * image + 0.5 * heat-map inside the box only.
            blended[roi] = cv2.addWeighted(res_img[roi], 0.7, heatmap[roi], 0.5, 0)
            # Brighten the blended region by a factor of 1.2, clipped to uint8 range.
            blended[roi] = np.clip(blended[roi] * 1.2, 0, 255).astype(np.uint8)

            res_img = blended
            last_heatmap = heatmap

        except Exception as e:
            print(f"Error in get_res_img: {e}")
            print(f"Mask shape: {sal.shape}")
            print(f"Box dimensions: width={box_w}, height={box_h}")

    if last_heatmap is None:
        last_heatmap = np.zeros((1, 1, 3), dtype=np.uint8)
    return res_img, last_heatmap






def put_text_box(bbox, cls_name, res_img):
    """
    Draw the detection box and its class label onto the image (in place).

    Args:
        bbox (list): [x1, y1, x2, y2]
        cls_name (str): class name written above the box
        res_img (np.ndarray): image to draw on

    Returns:
        np.ndarray: the same image object, after drawing
    """
    left, top, right, bottom = bbox
    green = (0, 255, 0)

    # Box outline.
    cv2.rectangle(res_img, (left, top), (right, bottom), green, 2)

    # Filled background sized to the label text, sitting on top of the box.
    label_size, _ = cv2.getTextSize(cls_name, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
    label_w, label_h = label_size
    cv2.rectangle(res_img, (left, top - label_h - 4), (left + label_w, top), green, -1)

    # Black label text on the green background.
    cv2.putText(res_img, cls_name, (left, top - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)

    return res_img

def concat_images(images):
    """
    Stack images vertically, ignoring empty ones.

    Args:
        images (list): images to stack

    Returns:
        np.ndarray: the stacked image, or an empty array when no image has any pixels
    """
    non_empty = []
    for image in images:
        # Zero-sized arrays are treated as missing images.
        if image.size != 0:
            non_empty.append(image)

    if len(non_empty) == 0:
        return np.array([])
    return np.vstack(non_empty)

def main_with_rise(img_path, model, rise_explainer, names, output_dir):
    """Detect, explain with RISE and save / show the visualisation for one image.

    The image is letterboxed to ``model.img_size``, run through the detector,
    and the top-1 detection of every batch item is explained. The original
    image and one annotated copy per explained detection are stacked
    vertically, written to ``output_dir`` with a ``-RISE`` suffix and displayed.

    Args:
        img_path (str): path of the image to read.
        model: detector; called on the input tensor and expected to expose
            ``img_size`` and ``device``.
        rise_explainer: callable ``(img, boxes, confidences, classes)``
            returning one saliency map (or ``None``) per batch item.
        names (list): class names indexed by class id; the id is used as
            label when the list is empty / ``None``.
        output_dir (str): directory the result image is written to.

    Returns:
        None. Problems are reported with ``print`` and end the function early.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"[ERROR] Failed to read image: {img_path}")
        return

    # Letterbox the image and keep the scale / padding for mapping boxes back.
    img_resized, ratio, (dw, dh) = letterbox(img, new_shape=model.img_size, auto=False, scaleFill=False)

    # Convert to the model input format.
    img_resized = img_resized[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB and HWC to CHW
    img_resized = np.ascontiguousarray(img_resized)
    torch_img = torch.from_numpy(img_resized).to(model.device).float().unsqueeze(0) / 255.0

    predictions, _ = model(torch_img)
    boxes, classes, confidences = predictions[0], predictions[1], predictions[3]

    if all(len(b) == 0 for b in boxes):
        print("[WARNING] No detections found in the image")
        return

    try:
        saliency_maps = rise_explainer(torch_img, boxes, confidences, classes)
    except Exception as e:
        print(f"[ERROR] Failed to generate saliency maps: {str(e)}")
        return

    original_bgr = img.copy()
    images = [original_bgr]

    for b, sal_map in enumerate(saliency_maps):
        if sal_map is None:
            continue

        try:
            if len(confidences[b]) == 0:
                continue

            # Same top-1 selection as the explainer uses.
            top1_idx = torch.argmax(torch.tensor(confidences[b])).item()
            bbox = boxes[b][top1_idx]
            cls = classes[b][top1_idx]
            cls_name = names[cls] if names else str(cls)

            # Box in original-image coordinates, used for the outline and label.
            x1, y1, x2, y2 = bbox
            x1 = (x1 - dw) / ratio[0]
            x2 = (x2 - dw) / ratio[0]
            y1 = (y1 - dh) / ratio[1]
            y2 = (y2 - dh) / ratio[1]
            bbox_original = [int(x1), int(y1), int(x2), int(y2)]

            vis_img = original_bgr.copy()
            # get_res_img undoes the letterbox transform itself, so it must get
            # the model-space box; passing bbox_original would convert twice and
            # misplace the heat-map relative to the drawn box.
            vis_img, heatmap = get_res_img(bbox, [sal_map], vis_img, ratio, dw, dh)
            vis_img = put_text_box(bbox_original, cls_name, vis_img)
            images.append(vis_img)
        except Exception as e:
            print(f"[ERROR] Error processing visualization for batch {b}: {str(e)}")
            continue

    final_image = concat_images(images)
    if final_image.size == 0:
        print("[WARNING] No valid images to display.")
        return

    img_name = split_extension(os.path.basename(img_path), suffix='-RISE')
    output_path = os.path.join(output_dir, img_name)
    os.makedirs(output_dir, exist_ok=True)
    cv2.imwrite(output_path, final_image)
    cv2_imshow(final_image)



# Entry-point script

def run_rise_on_yolov5(model_path=None, img_path=None, output_dir="/content/output",
                       img_size=640, maskspath='yolov5_rise_masks.npy', generate_new_masks=True):
    """Build the detector and the RISE explainer, then explain one test image.

    All parameters are optional; calling the function without arguments uses
    the Google Drive checkpoint / image and the settings below.

    Args:
        model_path (str, optional): YOLOv5 weight file.
        img_path (str, optional): image to explain.
        output_dir (str): directory the visualisation is written to.
        img_size (int): square network input size.
        maskspath (str): ``.npy`` file the RISE masks are saved to / loaded from.
        generate_new_masks (bool): regenerate the masks even if ``maskspath`` exists.
    """
    if model_path is None:
        model_path = "/gdrive/MyDrive/Deep_learning/CorneAI_nagoya/yolo5_forcresco/weights/eye_nii_2202_onecaseoneimage2_doctorcompare_yolov5s_epoch200_batch16_89.8p/last.pt"
    if img_path is None:
        img_path = "/gdrive/MyDrive/研究/進行中の研究/角膜スマートフォンAIプロジェクト/前原の240問/フォトスリット_serial/24.jpg"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # YOLOv5 detector.
    detector = YOLOV5TorchObjectDetector(
        model_weight=model_path,
        device=device,
        img_size=(img_size, img_size),
        names=names,
        mode='eval',
        confidence=0.25,
        iou_thresh=0.45,
        agnostic_nms=False
    )

    # The explainer reuses the network already loaded by the detector (same
    # weights, same device, eval mode) instead of loading the checkpoint again.
    rise_explainer = RISEForYOLOv5(model=detector.model, detector=detector, device=device, gpu_batch=4)

    # Generate the masks, or load the cached ones.
    if generate_new_masks or not os.path.isfile(maskspath):
        rise_explainer.generate_masks(N=500, s=8, p1=0.1, img_size=(img_size, img_size), savepath=maskspath)
    else:
        rise_explainer.load_masks(maskspath)
        if getattr(rise_explainer, 'p1', None) is None:
            # The cached masks are the ones written by the branch above (p1=0.1);
            # the explainer needs p1 for its normalisation step.
            rise_explainer.p1 = 0.1
        print('RISE masks loaded.')

    # Detection + RISE + visualisation for the test image.
    main_with_rise(img_path, detector, rise_explainer, names, output_dir)

# Entry point: run the RISE demo when this code is executed as the main module.
if __name__ == '__main__':
    run_rise_on_yolov5()


In [None]:
# YOLOv5 checkpoint and test image, both located on the Google Drive mounted at /gdrive.
model_path = "/gdrive/MyDrive/Deep_learning/CorneAI_nagoya/yolo5_forcresco/weights/eye_nii_2202_onecaseoneimage2_doctorcompare_yolov5s_epoch200_batch16_89.8p/last.pt"
img_path = "/gdrive/MyDrive/研究/進行中の研究/角膜スマートフォンAIプロジェクト/前原の240問/フォトスリット_serial/3.jpg"