# In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import cv2
import numpy as np
import os

class BackboneNet(nn.Module):
    """ResNet-50 backbone with a linear head that also returns its feature map.

    The torchvision ResNet-50 is truncated by dropping its last two child
    modules (its own global pooling and fully connected layer), so
    ``feature_extractor`` outputs the final convolutional feature map. That
    map is globally average-pooled and fed to a single linear classifier; the
    un-pooled map is returned as well so a class activation map (CAM) can be
    computed from it.

    Args:
        num_classes: Number of outputs of the linear classifier. Defaults to
            7, the value this model was originally hard-coded to.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Keep everything except the last two children of the ResNet.
        self.feature_extractor = torch.nn.Sequential(*list(model.children())[:-2])
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # Read the feature width from the original head instead of hard-coding
        # it (2048 for ResNet-50); the classifier stays inside a Sequential so
        # the parameter names (``classifier.0.*``) are unchanged.
        self.classifier = nn.Sequential(
            nn.Linear(model.fc.in_features, num_classes)
        )

    def forward(self, x):
        """Run the backbone on a batch of images.

        Args:
            x: Input batch; its first dimension is treated as the batch size.

        Returns:
            A tuple ``(logits, cam_features)``: the classifier output and the
            un-pooled feature map produced by ``feature_extractor``.
        """
        cam_features = self.feature_extractor(x)
        pooled = self.gap(cam_features)
        # Collapse the 1x1 spatial dimensions, leaving (batch, channels).
        pooled = torch.flatten(pooled, 1)
        logits = self.classifier(pooled)
        return logits, cam_features


from upsamplers import get_upsampler
# Module-level feature upsampler, called later as ``upsampler(feature_map, image)``
# before the CAM is computed.
# dim=2048 presumably matches the channel count of the backbone feature map
# — TODO confirm against get_upsampler.
# NOTE(review): .cuda() requires a CUDA device, unlike the `device` fallback
# defined further down in this file. No weights are loaded for the upsampler
# here — confirm whether get_upsampler restores trained parameters.
upsampler = get_upsampler('jbu_stack', dim=2048).cuda()

def preprocess_image(image_path, zoom_save_path=None):
    """Load an image file and turn it into a normalized, batched tensor.

    Args:
        image_path: Path of the image file to open.
        zoom_save_path: Unused. Kept (now optional) so existing two-argument
            calls keep working.

    Returns:
        The image as a tensor with a leading batch dimension of 1, normalized
        channel-wise and moved to the module-level ``device``.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        # Standard ImageNet channel statistics, matching the ImageNet-pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # convert() returns a fully loaded copy, so the file handle can be closed
    # deterministically right after it.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Use the shared device selection (CUDA when available, otherwise CPU)
    # instead of a hard-coded "cuda".
    return preprocess(image).unsqueeze(0).to(device)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate_cam(feature_map, fc_weights, class_idx):
    """Compute a normalized class activation map for one class.

    Args:
        feature_map: Feature tensor of shape (1, C, H, W) (or (C, H, W)); the
            batch dimension is squeezed away.
        fc_weights: Classifier weight matrix indexed by class, each row
            holding C channel weights.
        class_idx: Index of the class whose activation map is wanted.

    Returns:
        An (H, W) tensor: the channel-weighted sum of the feature map, passed
        through ReLU and min-max scaled to [0, 1]. If the map has no positive
        activation it is returned as all zeros instead of NaN.
    """
    feature_map = feature_map.squeeze(0)  # Remove batch dimension
    channels, height, width = feature_map.shape
    # Keep the weights on the same device as the features.
    fc_weight = fc_weights[class_idx].to(feature_map.device)

    # reshape (not view) so non-contiguous feature maps are accepted too.
    cam = torch.matmul(fc_weight, feature_map.reshape(channels, -1))
    cam = F.relu(cam.reshape(height, width))

    cam = cam - cam.min()
    peak = cam.max()
    # Guard the division: an all-zero map would otherwise become 0 / 0 = NaN.
    if peak > 0:
        cam = cam / peak
    return cam

def save_overlay_feature_map(image_path, cam, output_path1, output_path2, alpha=0.5):
    """Write a binary CAM mask and a heatmap overlay for one image.

    Args:
        image_path: Path of the original image the CAM belongs to.
        cam: 2-D CAM tensor; values are expected to lie in [0, 1].
        output_path1: Destination of the binary mask (255 where the resized
            CAM exceeds 0.5, 0 elsewhere).
        output_path2: Destination of the JET heatmap blended over the image.
        alpha: Blend weight of the heatmap; the image gets ``1 - alpha``.

    Raises:
        OSError: If the image at ``image_path`` cannot be read.
    """
    # cv2.imread signals failure by returning None rather than raising.
    original_image = cv2.imread(image_path)
    if original_image is None:
        raise OSError(f"cv2.imread could not read image: {image_path}")
    # Stay in OpenCV's BGR order throughout: applyColorMap produces BGR and
    # imwrite expects BGR, so converting the photo to RGB would swap its red
    # and blue channels in the saved overlay.
    height, width = original_image.shape[:2]

    cam = cam.detach().float().cpu().numpy()
    # Resize to the actual image size (cv2.resize takes (width, height)) so
    # the heatmap and the image can be blended for any input resolution.
    cam_resized = cv2.resize(cam, (width, height))
    # Keep the values in a range that is safe to scale into uint8.
    cam_resized = np.clip(np.nan_to_num(cam_resized), 0.0, 1.0)

    binary_mask = (cam_resized > 0.5).astype(np.uint8) * 255
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)

    overlay = cv2.addWeighted(original_image, 1 - alpha, heatmap, alpha, 0)

    cv2.imwrite(output_path1, binary_mask)
    cv2.imwrite(output_path2, overlay)

def process_directory_with_cam(input_dir, output_dir1, output_dir2, weights_path='best_model.pth'):
    """Generate CAM masks and overlays for every file in a directory.

    For each file the backbone predicts a class, the backbone feature map is
    passed through the module-level ``upsampler`` together with the input
    image, a CAM for the predicted class is computed from the result, and the
    mask / overlay images are written under the same file name.

    Args:
        input_dir: Directory containing the input images.
        output_dir1: Directory that receives the binary CAM masks.
        output_dir2: Directory that receives the heatmap overlays.
        weights_path: Checkpoint holding the ``BackboneNet`` state dict.
            Defaults to ``'best_model.pth'``.
    """
    os.makedirs(output_dir1, exist_ok=True)
    os.makedirs(output_dir2, exist_ok=True)

    backbone = BackboneNet().to(device).eval()
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # that come from a trusted source.
    backbone.load_state_dict(torch.load(weights_path, map_location=device))
    # The classifier weights do not change between images; fetch them once.
    fc_weights = backbone.classifier[0].weight

    for image_file in sorted(os.listdir(input_dir)):
        image_path = os.path.join(input_dir, image_file)
        # Skip sub-directories and other non-file entries.
        if not os.path.isfile(image_path):
            continue
        output_path1 = os.path.join(output_dir1, image_file)
        output_path2 = os.path.join(output_dir2, image_file)

        # preprocess_image does not use its second argument; pass None
        # explicitly rather than a name that is not defined in this scope.
        image = preprocess_image(image_path, None)

        with torch.no_grad():
            logits, feature_map = backbone(image)
            class_idx = torch.argmax(logits, dim=1).item()
            # Upsample the features first, then build the CAM from the
            # higher-resolution features.
            hr_features = upsampler(feature_map, image)
            hr_cam = generate_cam(hr_features, fc_weights, class_idx)

        save_overlay_feature_map(image_path, hr_cam, output_path1, output_path2, alpha=0.5)


# Placeholder paths — fill these in before running.
# input_dir holds the source images; output_dir1 receives the binary CAM
# masks and output_dir2 the heatmap overlays.
# NOTE(review): with empty strings, os.makedirs("") inside
# process_directory_with_cam will raise.
input_dir = ""
output_dir1 = ""
output_dir2 = ""

# Process the entire directory
process_directory_with_cam(input_dir, output_dir1, output_dir2)
