In [5]:
import wandb
import torch
from model import CNNModel
import torch.nn.functional as F
from collections import defaultdict
from torchvision.utils import make_grid
from datamodule import iNaturalistDataModule
from torchvision.transforms.functional import to_pil_image


In [2]:
# Trained run to evaluate; its filename encodes the hyperparameters it was trained with.
ckpt_path = (
    "./checkpoints/train_run_LR_0.0001_DATAUG:True_FILTERS:64_FILTERSIZE:5"
    "_FILTERORG:double_ACTIVATION:Mish_BATCHNORM:True_DROPOUT:0.3_DENSE:32.ckpt"
)
model = CNNModel.load_from_checkpoint(ckpt_path)

In [3]:
# Datamodule settings for evaluation.
# NOTE(review): data_augmentation=True — confirm the datamodule applies augmentation
# only to the training split; otherwise the visualised test images are augmented too.
datamodule_kwargs = dict(
    image_dim=224,
    val_split=0.2,
    data_augmentation=True,
    batch_size=256,
    num_workers=32,
)
data_module = iNaturalistDataModule(**datamodule_kwargs)
data_module.prepare_data()
data_module.setup()

In [6]:
IMAGES_PER_CLASS = 3  # samples shown per class in the prediction grid
# Derive the class count from the test dataset instead of assuming 10 classes.
NUM_CLASSES = len(data_module.test_dataset.class_to_idx)


def collect_samples_per_class(loader, num_classes, per_class):
    """Gather up to `per_class` images for each class from a (images, labels) loader.

    Args:
        loader: iterable of (batch_images, batch_labels) pairs.
        num_classes: labels are expected to be integers in 0..num_classes-1.
        per_class: maximum number of images kept per class.

    Returns:
        (images, labels): images stacked into one tensor ordered by class index,
        and an integer tensor of the matching labels. A class with fewer than
        `per_class` samples in the loader contributes only what was found.

    Raises:
        ValueError: if no image for any class in 0..num_classes-1 was found.
    """
    buckets = defaultdict(list)
    for batch_images, batch_labels in loader:
        for img, label in zip(batch_images, batch_labels):
            class_idx = label.item()
            if len(buckets[class_idx]) < per_class:
                # Iterating a tensor yields views into the batch; clone so the kept
                # image does not pin the entire batch tensor in memory.
                buckets[class_idx].append(img.clone())
        # Stop early once every class is full instead of scanning the whole loader.
        if all(len(buckets[c]) >= per_class for c in range(num_classes)):
            break

    selected_images, selected_labels = [], []
    for class_idx in range(num_classes):
        selected = buckets[class_idx][:per_class]
        selected_images.extend(selected)
        selected_labels.extend([class_idx] * len(selected))

    if not selected_images:
        raise ValueError("No images were collected from the loader for any class.")
    return torch.stack(selected_images), torch.tensor(selected_labels)


test_loader = data_module.test_dataloader()
images, labels = collect_samples_per_class(test_loader, NUM_CLASSES, IMAGES_PER_CLASS)

In [7]:
# Preferred GPU ordinal. Fall back to cuda:0 when the machine has fewer GPUs than
# that (moving to a non-existent ordinal raises), and to CPU when CUDA is absent.
GPU_INDEX = 2
if torch.cuda.is_available():
    gpu_ordinal = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_ordinal}")
else:
    device = torch.device("cpu")
images = images.to(device)
model = model.to(device)

In [8]:
# Class-name -> index mapping of the test dataset, plus its inverse for captions.
class_to_idx = data_module.test_dataset.class_to_idx
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

In [9]:
def denormalize(img_tensor, mean, std, device=None):
    """Undo a per-channel `(x - mean) / std` normalization.

    Args:
        img_tensor: image tensor shaped (C, H, W) or (N, C, H, W).
        mean: per-channel means (length C).
        std: per-channel standard deviations (length C).
        device: device for the mean/std tensors; defaults to `img_tensor.device`.

    Returns:
        `img_tensor * std + mean`, with the same shape as `img_tensor`.
    """
    if device is None:
        device = img_tensor.device
    # Shape (C, 1, 1) broadcasts against both (C, H, W) and (N, C, H, W) inputs
    # without changing their rank, and works for any channel count.
    mean = torch.tensor(mean, device=device).view(-1, 1, 1)
    std = torch.tensor(std, device=device).view(-1, 1, 1)
    return img_tensor * std + mean

# Per-channel statistics used to reverse the input normalization before display.
# NOTE(review): these must match the Normalize(mean, std) applied by
# iNaturalistDataModule — confirm, otherwise the rendered colours are wrong.
mean = [0.5, 0.5, 0.5]
std = [0.2, 0.2, 0.2]

# Perform inference on the selected test images.
model.eval()  # inference behaviour for dropout / batch-norm layers
with torch.no_grad():
    outputs = model(images)
    # softmax is monotonic, so the argmax of the raw logits is the same class.
    predictions = outputs.argmax(dim=1)

In [15]:
def _captioned_prediction(img, label, pred):
    """Turn one normalized image and its true/predicted indices into a captioned wandb.Image."""
    # denormalize works on a batch, so add and then drop a leading batch axis.
    restored = denormalize(img.unsqueeze(0), mean, std, device).squeeze(0)
    # Clamp to [0, 1] before converting to PIL.
    pil_image = to_pil_image(restored.cpu().clamp(0, 1))
    true_class = idx_to_class[label.item()]
    pred_class = idx_to_class[pred.item()]
    return wandb.Image(pil_image, caption=f"True: {true_class}, Pred: {pred_class}")


# One captioned image per selected test sample, logged to wandb below.
grid = [
    _captioned_prediction(img, label, pred)
    for img, label, pred in zip(images, labels, predictions)
]


In [16]:
# Open a wandb run in the assignment project to hold the prediction samples.
wandb.init(name="test_images_predictions", project="DA6401_A2")

In [17]:
# Log every captioned prediction under a single media key.
wandb.log(data={"Sample Test Predictions": grid})

In [None]:
def count_parameters(model):
    """Return the total number of elements across `model`'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

trainable_params = count_parameters(model)
print(f"Total Trainable Parameters: {trainable_params}")

Total Trainable Parameters: 4710250


In [None]:
from ptflops import get_model_complexity_info

# Profile the full LightningModule at the (C, H, W) resolution the datamodule uses (image_dim=224).
input_res = (3, 224, 224)
with torch.no_grad():
    macs, params = get_model_complexity_info(
        model,
        input_res,
        as_strings=True,
        print_per_layer_stat=False,
    )
print(f"MACs: {macs}")
print(f"Parameters: {params}")

MACs: 7.01 GMac
Parameters: 4.71 M


In [18]:
# Mark the wandb run opened above as finished so the logged predictions finish uploading.
# The ERROR printed below comes from wandb's notebook-history capture, which needs the
# `nbformat` package; it does not affect the logged data.
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.
