# In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Import our custom modules (assume they're in the same directory)
from notebook_1_data_prep_feature_extraction import MultiScaleFeatureExtractor
from notebook_2_hierarchical_transformer_implementation import HierarchicalTransformer
from notebook_3_training_and_evaluation import IntegratedModel

# Module-wide constants.
NUM_CLASSES = 10  # number of output classes (the 10 CIFAR-10 labels used in this file)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Load a pre-trained model (assume we've trained and saved it)
def load_model():
    model = IntegratedModel().to(DEVICE)
    model.load_state_dict(torch.load('hierarchical_transformer_model.pth'))
    model.eval()
    return model

# Prepare a single image for input
def prepare_image(img_path):
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img = torchvision.io.read_image(img_path)
    return transform(img).unsqueeze(0).to(DEVICE)

def visualize_attention(model, img):
    """Plot one self-attention heat-map per feature scale for the first image in *img*.

    Args:
        model: Model exposing ``feature_extractor`` and
            ``hierarchical_transformer.transformers``. The latter is assumed to
            hold one module per scale, each with a ``self_attn`` that provides
            ``compute_attention_weights`` (custom API, not shown here).
        img: Input batch, e.g. as returned by ``prepare_image``.
    """
    transformers = model.hierarchical_transformer.transformers
    attention_maps = []
    # Everything, including the attention weights, is computed without
    # autograd. Otherwise the weights may require grad and ``.numpy()`` raises.
    with torch.no_grad():
        features = model.feature_extractor(img)
        for transformer, feature in zip(transformers, features):
            # Flatten spatial dims and move them first. For a (B, C, H, W)
            # feature map this gives (H*W, B, C). Assumes a 4-D map; TODO confirm.
            tokens = feature.flatten(2).permute(2, 0, 1)
            transformed = transformer(tokens)
            # NOTE(review): the weights are recomputed on the layer's *output*.
            # That is not the attention the layer applied to its input. Behaviour
            # is kept as in the original; verify this is the intended view.
            weights = transformer.self_attn.compute_attention_weights(
                transformed, transformed, transformed
            )
            # Presumably (batch, heads, L, L): take sample 0 and average over
            # heads. TODO confirm the layout of compute_attention_weights.
            attention_maps.append(weights[0].mean(0).cpu().numpy())

    if not attention_maps:
        return

    num_scales = len(attention_maps)
    # One column per scale instead of a hard-coded 3. squeeze=False keeps
    # ``axes`` indexable even when there is only a single scale.
    _, axes = plt.subplots(1, num_scales, figsize=(5 * num_scales, 5), squeeze=False)
    for i, (ax, attention_map) in enumerate(zip(axes[0], attention_maps)):
        ax.imshow(attention_map, cmap='viridis')
        ax.set_title(f"Scale {i+1} Attention")
    plt.tight_layout()
    plt.show()

def analyze_image_types(model, dataloader):
    """Print per-class accuracy of *model* over every batch in *dataloader*.

    Args:
        model: Classifier returning one logit row per input image.
        dataloader: Iterable of ``(images, labels)`` batches. Labels are
            integer class indices in ``[0, NUM_CLASSES)``.

    Returns:
        A list with one accuracy percentage per class, in class-index order.
        The entry is ``None`` for a class that never appears in the data.
    """
    class_correct = torch.zeros(NUM_CLASSES, dtype=torch.long)
    class_total = torch.zeros(NUM_CLASSES, dtype=torch.long)

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Analyzing"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            # No .squeeze(): it collapses a size-1 batch into a 0-d tensor,
            # which cannot be indexed per sample.
            hits = (predicted == labels).cpu()
            labels_cpu = labels.cpu()
            # Tally per class with one vectorized call per batch. This replaces
            # a Python loop that synced with the device once per sample.
            class_total += torch.bincount(labels_cpu, minlength=NUM_CLASSES)
            class_correct += torch.bincount(labels_cpu[hits], minlength=NUM_CLASSES)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    accuracies = []
    for i in range(NUM_CLASSES):
        total = class_total[i].item()
        if total == 0:
            # Avoid ZeroDivisionError for classes absent from the dataloader.
            print(f'Accuracy of {classes[i]}: N/A (no samples)')
            accuracies.append(None)
            continue
        accuracy = 100 * class_correct[i].item() / total
        print(f'Accuracy of {classes[i]}: {accuracy:.2f}%')
        accuracies.append(accuracy)
    return accuracies

def experiment_scales(img, num_scales_list=(2, 3, 4)):
    """Compare class probabilities from models built with different scale counts.

    NOTE(review): the feature extractor and transformer are freshly constructed
    here and no weights are loaded. The plotted predictions therefore reflect
    random initialization, not a trained model.

    Args:
        img: Input batch, e.g. from ``prepare_image``. Only the first sample's
            prediction is plotted.
        num_scales_list: Iterable of scale counts to try. The default is now an
            immutable tuple, which avoids the shared-mutable-default pitfall.

    Returns:
        The raw model outputs, one per entry of *num_scales_list*.
    """
    num_scales_list = list(num_scales_list)  # accept any iterable; indexed below
    results = []
    for num_scales in num_scales_list:
        model = HierarchicalTransformer(num_scales=num_scales, d_model=256, nhead=8, num_classes=NUM_CLASSES).to(DEVICE)
        feature_extractor = MultiScaleFeatureExtractor(num_scales=num_scales).to(DEVICE)
        # New modules start in training mode. Switch to eval so that
        # dropout/batch-norm (if present) behave as at inference time.
        model.eval()
        feature_extractor.eval()

        with torch.no_grad():
            features = feature_extractor(img)
            output = model(features)

        results.append(output)

    # Split a fixed 0.8-wide slot per class among all configurations, so that
    # bars never spill into the neighbouring class group.
    n_configs = max(len(results), 1)
    width = 0.8 / n_configs
    positions = np.arange(NUM_CLASSES)

    fig, ax = plt.subplots(figsize=(10, 5))
    for i, result in enumerate(results):
        # Centre each group of bars on its class tick.
        offset = (i - (n_configs - 1) / 2) * width
        probs = result[0].softmax(0).cpu().numpy()
        ax.bar(positions + offset, probs, width=width, label=f'{num_scales_list[i]} scales')
    ax.set_xticks(positions)
    ax.set_xticklabels(('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'))
    ax.legend()
    plt.title('Model predictions with different numbers of scales')
    plt.show()
    return results

# Main execution: run every analysis when this file is executed directly.
if __name__ == "__main__":
    model = load_model()

    # Visualize attention patterns on a single sample image.
    img = prepare_image('sample_image.jpg')
    visualize_attention(model, img)

    # Analyze per-class behaviour on the CIFAR-10 test split. The datasets
    # yield PIL images, so ToTensor is valid in this pipeline.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
    analyze_image_types(model, testloader)

    # Experiment with different scale configurations.
    experiment_scales(img)

    # Report completion only when the analyses actually ran. At module level
    # this line also printed on a plain import of the file.
    print("Visualization and analysis complete.")