In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model.segformer_simple import Segformerwithcarbon
from model.unet import UNet_carbon
from dataset_segwithcarbon import CarbonDataset, CarbonDataset_csv
from model.util import select_device
from tqdm import tqdm
from model.metrics import CarbonLoss , CarbonLossWithRMSE
import os

def validate(model, device, loss_fn, loader, domain="Forest"):
    """Run one evaluation pass over ``loader`` and return batch-averaged metrics.

    The model is put into eval mode and the whole pass runs under
    ``torch.no_grad()``. Every metric is accumulated once per batch and divided
    by the number of batches, i.e. the result is an unweighted mean over
    batches (not over samples).

    Args:
        model: Object with ``.eval()`` that, when called on an input batch,
            returns the pair ``(gt_pred, carbon_pred)``.
        device: Target handed to ``.to(...)`` for each tensor of a batch.
        loss_fn: Callable invoked as ``loss_fn(gt_pred, gt, carbon_pred, carbon)``
            returning the 7-tuple ``(total_loss, cls_loss, reg_loss, acc_c,
            acc_r, miou, rmse)``. The first three must support ``.item()``;
            the remaining four are added to the running sums as-is.
        loader: Iterable yielding ``(x, carbon, gt)`` batches.
        domain: Label shown in the progress bar and in the error message.

    Returns:
        dict with the keys ``total_loss``, ``total_cls_loss``,
        ``total_reg_loss``, ``total_acc_c``, ``total_acc_r``, ``total_miou``
        and ``total_rmse`` (in that order), each holding the per-batch mean.

    Raises:
        ValueError: If ``loader`` yields no batches. Previously this surfaced
            as a bare ``ZeroDivisionError`` from the final averaging step.
    """
    metric_names = ('total_loss', 'total_cls_loss', 'total_reg_loss',
                    'total_acc_c', 'total_acc_r', 'total_miou', 'total_rmse')
    sums = dict.fromkeys(metric_names, 0.0)
    num_batches = 0

    model.eval()
    with torch.no_grad():
        for x, carbon, gt in tqdm(loader, desc=f"Validating {domain}"):
            x, carbon, gt = x.to(device), carbon.to(device), gt.to(device)
            gt_pred, carbon_pred = model(x)
            # Dimension 1 of the label tensor is dropped before it reaches the loss
            # (presumably a singleton channel axis -- confirm against the dataset).
            total_loss, cls_loss, reg_loss, acc_c, acc_r, miou, rmse = loss_fn(
                gt_pred, gt.squeeze(1), carbon_pred, carbon
            )

            # Loss terms are converted to Python floats; the other metrics are
            # accumulated in whatever numeric form loss_fn returns them.
            sums['total_loss'] += total_loss.item()
            sums['total_cls_loss'] += cls_loss.item()
            sums['total_reg_loss'] += reg_loss.item()
            sums['total_acc_c'] += acc_c
            sums['total_acc_r'] += acc_r
            sums['total_miou'] += miou
            sums['total_rmse'] += rmse
            num_batches += 1

    if num_batches == 0:
        # Fail with an actionable message instead of dividing by zero below.
        raise ValueError(
            f"Cannot validate domain {domain!r}: the loader yielded no batches."
        )

    return {name: total / num_batches for name, total in sums.items()}

def get_transforms(label_size, image_size=512):
    """Build the three transform pipelines used by the carbon datasets.

    Args:
        label_size: Edge length (pixels) the label is resized to; the label
            pipeline resizes to ``(label_size, label_size)`` and does nothing else.
        image_size: Edge length (pixels) for the image and ``sh`` inputs.
            Defaults to 512, which was previously hard-coded.

    Returns:
        Tuple ``(image_transform, sh_transform, label_transform)``. The first
        two are separate but identically configured pipelines
        (resize to ``image_size`` square, then ``ToTensor``).
    """
    def _square_to_tensor():
        # A fresh Compose per call so the image and sh pipelines stay distinct objects.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    image_transform = _square_to_tensor()
    sh_transform = _square_to_tensor()

    # NOTE(review): Resize is used with its default interpolation here. If the
    # label holds discrete class ids, nearest-neighbour may be required --
    # confirm how the dataset applies this transform.
    label_transform = transforms.Compose([
        transforms.Resize((label_size, label_size)),
    ])
    return image_transform, sh_transform, label_transform

# Resolve the compute device once; the model, the loss module and every batch
# in validate() are moved to it. (How the device is chosen is up to
# model.util.select_device -- not visible here.)
device = select_device()

In [None]:
# CSV manifests of the validation splits evaluated in this notebook.
forest_fp = "val_AP25_Forest_IMAGE.csv"
city_fp = "val_AP25_City_IMAGE.csv"

# Input pipelines; labels are resized to 512 as well.
image_transform, sh_transform, label_transform = get_transforms(512)

# Every listed manifest maps to the same value, 4
# (presumably the number of classes for that dataset -- TODO confirm).
FOLDER_PATH = dict.fromkeys(
    (
        'AP10_Forest_IMAGE.csv',
        'AP25_Forest_IMAGE.csv',
        'AP10_City_IMAGE.csv',
        'AP25_City_IMAGE.csv',
        'SN10_Forest_IMAGE.csv',
    ),
    4,
)
def remove_profile_keys(state_dict):
    """Return a new dict holding ``state_dict`` minus its profiling entries.

    An entry is dropped when its key ends with ``total_ops`` or
    ``total_params`` (presumably counters left behind by a FLOPs/params
    profiler -- confirm). The input mapping is not modified and the
    remaining entries keep their original order.
    """
    profiler_suffixes = ('total_ops', 'total_params')
    cleaned = {}
    for name, value in state_dict.items():
        if name.endswith(profiler_suffixes):
            continue
        cleaned[name] = value
    return cleaned

# Load the state dict. Tensors are mapped to CPU here; the model itself is moved
# to `device` after the weights have been restored.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (or consider weights_only=True if the installed torch supports it).
# NOTE(review): the checkpoint name contains "256" while this notebook resizes
# inputs to 512 -- confirm the intended evaluation resolution.
state_dict = torch.load('checkpoints/Unet/UnetUnetAP25_Forest_256_best.pth', map_location=torch.device('cpu'))

# Remove the profiling keys (entries ending in total_ops / total_params) so the
# strict load into the model does not see unexpected keys.
clean_state_dict = remove_profile_keys(state_dict)

# Evaluation settings: one sample per batch; clss_num is passed to both the
# model and the loss as num_classes.
batch_size = 1
clss_num = 4

# Forest validation split, iterated in fixed order (shuffle=False).
forest_dataset = CarbonDataset_csv(forest_fp, image_transform, sh_transform, label_transform, mode="Valid",combine=True)
forest_loader = DataLoader(forest_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# City validation split, configured identically to the forest one.
city_dataset = CarbonDataset_csv(city_fp, image_transform, sh_transform, label_transform, mode="Valid",combine=True)
city_loader = DataLoader(city_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)


In [None]:

# Build the U-Net with the configured number of classes.
model = UNet_carbon(num_classes=clss_num)
# Load the cleaned state dict into the model; strict=True turns any missing or
# unexpected key into an error instead of silently skipping it.
model.load_state_dict(clean_state_dict, strict=True)

model = model.to(device)
# cls_lambda / reg_lambda presumably weight the classification and regression
# terms of the combined loss -- confirm in model.metrics.
loss_fn = CarbonLossWithRMSE(num_classes=clss_num, cls_lambda=1, reg_lambda=0.0005).to(device)
