In [None]:
from model import SegmentationModel
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))

# Build the network and move it onto the chosen device.
# (presumably SegmentationModel is an nn.Module, whose .to() moves parameters in place)
model = SegmentationModel()
model.to(device)

print(model)

In [None]:
import torchvision.transforms as T
from PIL import Image
from config import imageSize

# Preprocessing pipeline: resize to the configured `imageSize`, then convert the
# PIL image to a float tensor (ToTensor also rescales pixel values to [0, 1]).
transform = T.Compose([T.Resize(imageSize), T.ToTensor()])

# Sample COCO training image used for every inference cell below.
img_path = "./dataset/train2014/COCO_train2014_000000000009.jpg"
img = Image.open(img_path).convert("RGB")

# Prepend a batch dimension of size 1 and place the tensor on the model's device.
img_to_tensor = transform(img).unsqueeze(0).to(device)


In [None]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Profile a single forward pass of a freshly built model on the sample image.
# Only ask for CUDA activity (and sort by CUDA time) when the model actually runs
# on a GPU. On a CPU-only run there are no CUDA events to rank by, so sort by CPU time instead.
on_cuda = device.type == "cuda"
activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if on_cuda else [])
sort_key = "cuda_time_total" if on_cuda else "cpu_time_total"

model = SegmentationModel().to(device)
model.eval()

# Warm-up pass outside the profiler so one-time costs (lazy CUDA/cuDNN init,
# allocator growth) are not attributed to the profiled inference call.
with torch.no_grad():
    _ = model(img_to_tensor)
if on_cuda:
    torch.cuda.synchronize()

with profile(activities=activities, record_shapes=True) as prof:
    with record_function("model_inference"):
        with torch.no_grad():
            _ = model(img_to_tensor)

print(prof.key_averages().table(sort_by=sort_key, row_limit=20))


In [None]:
# Export the model to ONNX by running it on the sample input.
# Put the model in eval mode explicitly so inference-time behaviour is what gets
# exported even if this cell is run without the profiling cell.
model.eval()
torch.onnx.export(model, img_to_tensor, "Segmentation.onnx")
# View .onnx in Netron: https://netron.app

In [None]:
import matplotlib.pyplot as plt

# One inference pass without autograd bookkeeping; copy the result to host memory.
with torch.no_grad():
    output = model(img_to_tensor)
    # squeeze() drops every size-1 dim (e.g. the batch dim), so a single-channel
    # output becomes a 2-D array imshow can render — assumes 1 output channel, TODO confirm.
    output_img = output.squeeze().cpu().numpy()

fig, ax = plt.subplots()
ax.imshow(output_img, cmap="gray")
ax.set_title("Model Output")
ax.axis("off")
plt.show()

In [None]:
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

def _filter_to_image(filt):
    """Convert one conv filter of shape (in_channels, kH, kW) into an ``(image, cmap)`` pair.

    - 3 input channels  -> RGB image (channels moved last), min-max scaled to [0, 1], cmap None.
    - 1 input channel   -> the single kernel plane, cmap 'gray'.
    - any other count   -> mean over input channels, cmap 'viridis'.
    Single-plane images are min-max scaled to [0, 1] unless they are (near-)constant.
    """
    n_channels = filt.shape[0]
    if n_channels == 3:
        rgb = np.transpose(filt, (1, 2, 0))
        # +1e-8 guards against division by zero for a constant filter.
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)
        return rgb, None

    if n_channels == 1:
        plane, cmap = filt[0], 'gray'
    else:
        plane, cmap = filt.mean(axis=0), 'viridis'

    lo, hi = plane.min(), plane.max()
    if hi - lo > 1e-8:
        plane = (plane - lo) / (hi - lo)
    return plane, cmap


def visualize_conv_filters_notebook(model, max_filters: "int | None" = None, max_cols: int = 8):
    """
    Notebook-friendly visualization of convolutional filters.

    Draws one figure per ``nn.Conv2d`` module found in ``model.modules()`` (in that
    order), with one subplot per output filter.

    Args:
        model: module whose Conv2d layers are visualized.
        max_filters: if given, show at most this many filters per layer. Useful for
            wide layers, where one subplot per filter yields very large figures.
            ``None`` (default) shows every filter.
        max_cols: maximum number of subplot columns per figure (default 8).

    Raises:
        ValueError: if ``max_cols`` < 1 or ``max_filters`` is not None and < 1.
    """
    if max_cols < 1:
        raise ValueError("max_cols must be >= 1")
    if max_filters is not None and max_filters < 1:
        raise ValueError("max_filters must be >= 1 or None")

    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]

    print(f"Found {len(conv_layers)} convolutional layers\n")

    for layer_idx, layer in enumerate(conv_layers):
        # detach() instead of .data: same values, idiomatic way to leave autograd.
        filters = layer.weight.detach().cpu().numpy()
        n_out, n_channels, k_h, k_w = filters.shape

        n_shown = n_out if max_filters is None else min(n_out, max_filters)
        if n_shown < 1:
            continue  # nothing to draw; subplots(0, 0) would raise

        n_cols = min(max_cols, n_shown)
        n_rows = (n_shown + n_cols - 1) // n_cols  # ceil division

        # squeeze=False always yields a 2-D Axes array, so ravel() works for any grid shape.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5), squeeze=False)
        axes = axes.ravel()

        title = f'Layer {layer_idx}: Conv2d | Out: {n_out} | In: {n_channels} | Kernel: {k_h}x{k_w}'
        if n_shown < n_out:
            title += f' | showing first {n_shown}'
        fig.suptitle(title, fontsize=12, fontweight='bold')

        for ax, filt in zip(axes, filters[:n_shown]):
            filter_img, cmap = _filter_to_image(filt)
            ax.imshow(filter_img, cmap=cmap)

        # Hide ticks on every cell, including unused padding cells in the last row.
        for ax in axes:
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        # Free the figure explicitly, even on backends where plt.show() does not close it.
        plt.close(fig)



def layer_summary(model):
    """
    Quick summary of all conv layers.

    Prints a fixed-width table with one row per ``nn.Conv2d`` module found in
    ``model.modules()`` (in that order): output channels, the weight tensor's
    second dimension (input channels per group), kernel size as ``HxW``, and the
    filter count (equal to the output channels).
    """
    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]

    print(f"{'='*70}")
    print(f"Total Convolutional Layers: {len(conv_layers)}")
    print(f"{'='*70}")
    print(f"{'Layer':<6} | {'Out Ch':<7} | {'In Ch':<6} | {'Kernel':<8} | {'Filters':<8}")
    print(f"{'-'*70}")

    for i, layer in enumerate(conv_layers):
        out_ch, in_ch, k_h, k_w = layer.weight.shape
        # Build the whole "HxW" string before padding it. Padding only W would push
        # multi-digit kernels (e.g. 11x11) past the column width and break alignment.
        kernel = f"{k_h}x{k_w}"
        print(f"{i:<6} | {out_ch:<7} | {in_ch:<6} | {kernel:<8} | {out_ch:<8}")

    print(f"{'='*70}\n")
    
# Tabular overview first, then the per-layer filter grids.
layer_summary(model)
visualize_conv_filters_notebook(model)
