# --- PyTorch Fully Convolutional Network Inference ---

In [1]:
import torch
from torchvision import transforms
from matplotlib import pyplot as plt
import numpy as np
from trainer.trainer3 import LandCoverModel

In [2]:
def _to_displayable(array):
    """
    Convert one channel-first array into an ``(array, imshow_kwargs)`` pair for plotting.

    ``plt.imshow`` only renders (H, W), (H, W, 3) and (H, W, 4) arrays, so:
    - 1-channel arrays are squeezed to (H, W);
    - 3- or 4-channel arrays are moved to channel-last (H, W, C);
    - any other channel count (e.g. one score map per class) is collapsed to an
      (H, W) class-index map via ``argmax`` over the channel axis and drawn on a
      fixed 0..C-1 colour scale, so a prediction and a label with the same
      channel count use the same colour for the same class.
    Arrays that are not 3-D are returned unchanged.
    """
    array = np.asarray(array)
    if array.ndim != 3:
        return array, {}
    channels = array.shape[0]
    if channels == 1:
        return array[0], {}
    if channels in (3, 4):
        return np.transpose(array, (1, 2, 0)), {}
    class_map = np.argmax(array, axis=0)
    return class_map, {"vmin": 0, "vmax": channels - 1, "interpolation": "nearest"}


def visualize_results(images, predictions, labels, num_samples):
    """
    Helper function to visualize the results of the model predictions.

    Draws one row per sample with three panels: input image, model prediction,
    and ground truth.

    Args:
        images: sequence of channel-first (C, H, W) arrays for the model inputs.
        predictions: sequence of model-output arrays, channel-first (C, H, W) when 3-D.
        labels: sequence of ground-truth arrays, channel-first (C, H, W) when 3-D.
        num_samples: maximum number of rows to draw. It is clamped to the number of
            samples actually provided, so a short list does not raise IndexError.
    """
    num_rows = min(num_samples, len(images), len(predictions), len(labels))
    if num_rows <= 0:
        return  # nothing to plot; avoid creating an empty, zero-height figure

    plt.figure(figsize=(10, 2 * num_rows))
    for idx in range(num_rows):
        panels = (
            (images[idx], "Input Image"),
            (predictions[idx], "Model Prediction"),
            (labels[idx], "Ground Truth"),
        )
        for col, (array, title) in enumerate(panels, start=1):
            # Multi-class score maps (e.g. 9 channels) are not valid imshow input,
            # so every panel goes through the same shape normalisation.
            display_array, imshow_kwargs = _to_displayable(array)
            plt.subplot(num_rows, 3, 3 * idx + col)
            plt.imshow(display_array, **imshow_kwargs)
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()

In [3]:
def model_inference_and_visualization(model, dataloader, device, num_samples=5):
    """
    Perform inference on the data provided by the dataloader and visualize the results.

    Args:
        model: torch module to evaluate. It is switched to eval mode and moved to
            ``device`` in place.
        dataloader: iterable yielding ``(inputs, labels)`` batches of tensors.
        device: torch device on which inference runs.
        num_samples: number of individual samples (not batches) to collect and plot.
    """
    model.eval()  # Set the model to evaluation mode
    model.to(device)

    images, predictions, labels = [], [], []
    with torch.no_grad():  # Turn off gradients to speed up this part
        for inputs, labels_batch in dataloader:
            # Count collected samples rather than batches, so that batch_size > 1
            # does not gather num_samples * batch_size items.
            if len(images) >= num_samples:
                break
            # Sigmoid converts logits to per-channel probabilities
            # (assumes the model emits raw logits).
            outputs = torch.sigmoid(model(inputs.to(device)))

            # Labels are only needed on the host for plotting, so they are never
            # copied to the device.
            images.extend(inputs.cpu().numpy())
            predictions.extend(outputs.cpu().numpy())
            labels.extend(labels_batch.cpu().numpy())

    # The last batch may overshoot; keep exactly the requested number of samples.
    images = images[:num_samples]
    predictions = predictions[:num_samples]
    labels = labels[:num_samples]
    if not images:
        return  # empty dataloader (or num_samples <= 0): nothing to visualize

    # Pass the real count so a dataloader shorter than num_samples cannot cause
    # an IndexError in the plotting helper.
    visualize_results(images, predictions, labels, len(images))

In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [8]:
# Restore trained weights from the best checkpoint. The loaded object is indexed
# with 'state_dict', so the weights live under that key.
# NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
CHECKPOINT_PATH = 'trainer/checkpoints/best-checkpoint.ckpt'
model = LandCoverModel().to(device)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device)['state_dict'])

<All keys matched successfully>

In [9]:
model  # show the module repr (layer stack and loss_fn) as a quick sanity check

LandCoverModel(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(32, 16, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(16, 9, kernel_size=(1, 1), stride=(1, 1))
  )
  (loss_fn): BCEWithLogitsLoss()
)

In [10]:
# Calling test_dataloader() directly previously failed with
# "AttributeError: 'LandCoverModel' object has no attribute 'test_dataset'":
# nothing had created the test dataset yet.
# NOTE(review): assumes LandCoverModel.setup() (Lightning-style hook) builds
# `test_dataset` for stage "test" — verify against trainer/trainer3.py.
if not hasattr(model, "test_dataset"):
    model.setup("test")
test_loader = model.test_dataloader()

AttributeError: 'LandCoverModel' object has no attribute 'test_dataset'

In [None]:
# Run the model over test_loader and plot input / prediction / ground-truth panels.
model_inference_and_visualization(
    model=model,
    dataloader=test_loader,
    device=device,
    num_samples=5,
)