In [None]:
from model import UNet
import torch
import cv2
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt

: 

In [None]:
# Checkpoint to evaluate and the number of output classes passed to UNet.
MODEL_PATH = 'latest_model_focal_ugly.pth'
N_CLASSES = 6
# cv2.resize expects dsize as (width, height); this is half of 3200 x 2496.
TARGET_SIZE = tuple(side // 2 for side in (3200, 2496))

In [None]:
def test_model(img_path=None):
    """Segment one image with the trained UNet and plot the predicted class map.

    Builds ``UNet(N_CLASSES)`` and loads its weights from ``MODEL_PATH``. It then
    converts the image to RGB, resizes it to ``TARGET_SIZE``, applies the fixed
    mean/std normalization below, and shows the resized input next to the
    colour-coded argmax prediction.

    Args:
        img_path: Path of the image to segment. ``None`` (the default) falls back
            to the notebook-level ``IMG_PATH``, so the original zero-argument
            call keeps working.

    Returns:
        numpy.ndarray: Per-pixel class indices, i.e. the argmax over dim 0 of the
        first output item (presumably the class/channel axis — confirm against
        UNet). The array is returned so later cells can score the prediction
        without relying on leftover kernel state.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
        ValueError: If the legend labels do not match ``N_CLASSES``.
    """
    if img_path is None:
        img_path = IMG_PATH  # notebook-level default, set in a later cell

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    model = UNet(N_CLASSES).to(device)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    # cv2.imread signals failure by returning None instead of raising.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR
    img = cv2.resize(img, TARGET_SIZE)  # dsize is (width, height)

    # Convert to tensor and normalize
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)  # add a batch dimension

    with torch.no_grad():
        output = model(img_tensor)
        pred = torch.argmax(output[0], dim=0).cpu().numpy()

    class_labels = ['Year', 'Date', 'Longitude', 'Latitude', 'Temperature', 'Background']
    # Legend entries are paired with colours by index; a length mismatch would
    # silently drop or mislabel classes.
    if len(class_labels) != N_CLASSES:
        raise ValueError(
            f'class_labels has {len(class_labels)} entries but N_CLASSES is {N_CLASSES}'
        )

    # One RGBA colour per class; indexing with the label map colours each pixel.
    class_colors = plt.cm.tab10(np.linspace(0, 1, N_CLASSES))
    colored_pred = class_colors[pred]

    fig, (ax_img, ax_pred) = plt.subplots(1, 2, figsize=(15, 5))

    ax_img.imshow(img)
    ax_img.set_title('Original Image')
    ax_img.axis('off')

    ax_pred.imshow(colored_pred[..., :3])  # Remove alpha channel
    ax_pred.set_title('Segmentation Prediction')
    ax_pred.axis('off')

    patches = [plt.Rectangle((0, 0), 1, 1, fc=class_colors[i][:3]) for i in range(N_CLASSES)]
    # With an explicit loc, the legend's upper-left corner sits at bbox_to_anchor,
    # i.e. just to the right of the prediction axes.
    ax_pred.legend(patches, class_labels, bbox_to_anchor=(1.05, 1), loc='upper left')

    fig.tight_layout()
    # plt.show() only forwards backend options such as `block`; savefig-style
    # kwargs like bbox_inches raise TypeError there.
    plt.show()
    plt.close(fig)

    return pred

In [None]:
# Candidate input images; change the index to run the model on a different one
# instead of toggling commented-out assignments.
SAMPLE_IMAGES = [
    '8013620831-0098.jpg-b.jpg',
    '8013620831-0187.jpg-t.jpg',
    '8013620831-0077.jpg-t.jpg',
    '8013620831-0061.jpg-t.jpg',
]
IMG_PATH = SAMPLE_IMAGES[3]

In [None]:
test_model();  # trailing ';' keeps any return value from being echoed below the plot

In [None]:
# Load the ground-truth mask (required for the Dice score below)
ground_truth_path = "data/processed/masks/your_ground_truth_mask.png"  # Replace with the actual path
ground_truth = cv2.imread(ground_truth_path, cv2.IMREAD_GRAYSCALE)
if ground_truth is None:
    # cv2.imread returns None instead of raising when the file is missing or unreadable.
    raise FileNotFoundError(f"Could not read ground-truth mask: {ground_truth_path}")
# Nearest-neighbour resampling keeps mask values discrete; the default bilinear
# interpolation would blend neighbouring label values along region boundaries.
ground_truth = cv2.resize(ground_truth, TARGET_SIZE, interpolation=cv2.INTER_NEAREST)

# Compute Dice score or other metrics
def compute_dice_score(pred, ground_truth):
    """Dice coefficient ``2 * |A ∩ B| / (|A| + |B|)`` between two masks.

    Accepts torch tensors or NumPy arrays. The evaluation cell passes NumPy
    boolean masks, which ``torch.sum`` rejects, so inputs are first converted
    with ``torch.as_tensor``. Values are cast to float, so boolean masks give
    the hard Dice score and probability maps give the soft Dice score.

    Args:
        pred: Predicted mask.
        ground_truth: Reference mask with the same shape as ``pred``.

    Returns:
        torch.Tensor: 0-d float tensor; use ``.item()`` for a Python float.
        If both masks are empty, the score is defined as 1.0 (perfect agreement)
        instead of evaluating 0 / 0.

    Raises:
        ValueError: If the two masks have different shapes. Broadcasting would
            otherwise produce a silently wrong score.
    """
    pred = torch.as_tensor(pred).float()
    ground_truth = torch.as_tensor(ground_truth).float()
    if pred.shape != ground_truth.shape:
        raise ValueError(
            f'Shape mismatch: pred {tuple(pred.shape)} vs ground_truth {tuple(ground_truth.shape)}'
        )
    intersection = torch.sum(pred * ground_truth)
    total = torch.sum(pred) + torch.sum(ground_truth)
    if total == 0:
        return torch.tensor(1.0)
    return 2. * intersection / total

# `pred` is local to test_model(), so request it explicitly instead of relying on
# a variable left over in the kernel. This re-runs inference (and its plot) for IMG_PATH.
pred = test_model()
if pred is None:
    raise RuntimeError("test_model() must return its predicted label map to compute a Dice score")

# Assuming binary segmentation (modify as necessary for multi-class)
# NOTE(review): `> 0` counts every class except index 0 as foreground, while the
# plot legend labels index 5 as 'Background' — confirm how the mask encodes classes.
pred_mask = torch.as_tensor(pred > 0)
ground_truth_mask = torch.as_tensor(ground_truth > 0)
dice_score = compute_dice_score(pred_mask, ground_truth_mask)
print(f'Dice Score: {dice_score.item():.4f}')