In [None]:
import torch 
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F
import torchvision 
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt 
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree

In [None]:
# Preprocessing shared by both splits: ToTensor followed by a single-channel
# Normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# NOTE(review): the root directory is named like a flower-dataset archive even
# though MNIST is what gets downloaded into it — confirm the path is intended.
trainset = torchvision.datasets.MNIST(
    root='126156009/17flowerdataset.zip',
    train=True,
    download=True,
    transform=transform,
)
# Training batches are reshuffled on every pass.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(
    root='126156009/17flowerdataset.zip',
    train=False,
    download=True,
    transform=transform,
)
# Test batches keep the dataset order.
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=64, shuffle=False)

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv + pool stages, then two linear layers.

    The flatten step in ``forward`` hard-codes ``32 * 7 * 7`` features, which is
    what the two 2x2 poolings leave for a 28x28 single-channel input.
    The network returns raw 10-way scores; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it determines the order in which the
        # layers draw their initial weights from the RNG.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the ``(batch, 10)`` score tensor for the input batch ``x``."""
        pooled1 = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        pooled2 = F.max_pool2d(F.relu(self.conv2(pooled1)), 2, 2)
        flat = pooled2.view(-1, 32 * 7 * 7)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Seed PyTorch's global RNG before the model is built so that weight
# initialisation (and DataLoader shuffling, which draws from the same global
# generator when no generator is passed) is repeatable across Restart & Run All.
# Some GPU kernels can still be nondeterministic.
torch.manual_seed(0)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
# The loss consumes the raw scores returned by CNN.forward plus integer labels.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
def train(model, trainloader, criterion, optimizer, epochs=3, log_logits=False):
    """Train ``model`` in place and print the mean batch loss after every epoch.

    Args:
        model: ``nn.Module`` to optimise; it is switched to training mode.
        trainloader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``trainloader``.
        log_logits: When True, also print the first sample's outputs for every
            batch. Off by default because it emits one line per optimisation
            step and floods the notebook output.
    """
    model.train()

    # Move each batch to wherever the model's weights live instead of relying
    # on a notebook-level ``device`` variable being defined and in sync.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Guard the per-epoch average against an empty loader (len == 0).
    num_batches = max(len(trainloader), 1)

    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(images)

            if log_logits:
                print(f"Logits for first image in batch: {outputs[0]}")

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {running_loss / num_batches:.4f}")

# Fit the CNN using the function's default number of epochs.
train(model=model, trainloader=trainloader, criterion=criterion, optimizer=optimizer)


In [None]:
def extract_features(model, dataloader, log_logits=False):
    """Run ``model`` over ``dataloader`` and collect its outputs with the labels.

    The "features" are the model's final outputs (for the CNN in this notebook,
    its 10 class scores), not an intermediate representation.

    Args:
        model: ``nn.Module``; it is switched to eval mode and left there.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        log_logits: When True, print the first sample's outputs for every
            batch. Off by default to avoid one output line per batch.

    Returns:
        ``(features, labels)`` as NumPy arrays with one row / entry per sample,
        in iteration order. Both are empty 1-D arrays when the loader yields
        no batches.
    """
    model.eval()

    # Send inputs to the device the model's weights are on rather than to a
    # notebook-level ``device`` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    feature_batches, label_batches = [], []
    with torch.no_grad():
        for images, lbls in dataloader:
            outputs = model(images.to(target_device))

            if log_logits:
                print(f"Logits for first image in batch: {outputs[0]}")

            # Keep whole batches and join them once at the end.
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(lbls.cpu().numpy())

    if not feature_batches:
        return np.array([]), np.array([])
    return np.concatenate(feature_batches), np.concatenate(label_batches)

# Collect the CNN outputs and dataset labels for both splits.
X_train, y_train = extract_features(model=model, dataloader=trainloader)
X_test, y_test = extract_features(model=model, dataloader=testloader)

In [None]:
# Shallow tree fitted on the CNN's outputs. ``random_state`` is pinned because
# scikit-learn permutes the candidate features at every split, so ties between
# equally good splits can otherwise be broken differently from run to run.
# NOTE(review): the tree is fitted against the dataset labels (y_train), not
# against the CNN's own predictions — confirm that is the intended surrogate target.
dt = DecisionTreeClassifier(max_depth=5, random_state=0)
dt.fit(X_train, y_train)


# Mean accuracy of the tree on the test split, measured against y_test.
acc = dt.score(X_test, y_test)
print(f"Surrogate Model Accuracy: {acc * 100:.2f}%")


In [None]:

def visualize_surrogate_model(dt):
    """Draw the fitted decision tree ``dt`` with matplotlib.

    Feature and class labels are taken from the estimator itself
    (``n_features_in_`` and ``classes_``) rather than from notebook globals, so
    the plot stays correct for any fitted tree, including one whose training
    data did not contain every class.

    Args:
        dt: A fitted ``DecisionTreeClassifier``.
    """
    feature_names = [f"Feature {i}" for i in range(dt.n_features_in_)]
    # plot_tree pairs class_names with the tree's classes by position, so the
    # names are derived from dt.classes_ instead of assuming 0..9 are all present.
    class_names = [str(label) for label in dt.classes_]

    plt.figure(figsize=(20, 16))
    plot_tree(
        dt,
        filled=True,
        feature_names=feature_names,
        class_names=class_names,
        rounded=True,
        fontsize=14,
    )
    plt.title("Surrogate Model - Decision Tree")
    plt.show()

# Render the tree fitted above.
visualize_surrogate_model(dt=dt)


def plot_feature_importance(dt, feature_names):
    """Plot the estimator's feature importances as a horizontal bar chart.

    Bars are ordered from the largest importance (position 0) to the smallest.

    Args:
        dt: Fitted estimator exposing ``feature_importances_`` as a NumPy array.
        feature_names: One display name per feature, indexed like
            ``feature_importances_``. ``None`` falls back to ``"Feature i"``.

    Raises:
        ValueError: If ``feature_names`` does not have one entry per feature.
    """
    feature_importances = dt.feature_importances_
    num_features = len(feature_importances)

    if feature_names is None:
        feature_names = [f"Feature {i}" for i in range(num_features)]
    elif len(feature_names) != num_features:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but the model "
            f"has {num_features} features"
        )

    # Feature indices sorted by descending importance.
    indices = np.argsort(feature_importances)[::-1]
    positions = range(num_features)

    plt.figure(figsize=(10, 6))
    plt.title("Feature Importances (Surrogate Model)")
    plt.barh(positions, feature_importances[indices], align="center")
    # Use the caller-supplied names so the tick labels match the bars.
    plt.yticks(positions, [feature_names[i] for i in indices])
    plt.xlabel("Importance")
    plt.show()

# One "Feature i" label per column of the training matrix.
plot_feature_importance(
    dt=dt,
    feature_names=[f"Feature {i}" for i in range(X_train.shape[1])],
)


def visualize_feature_maps(model, input_image):
    """Show every output channel of ``model.conv1`` and ``model.conv2`` for one image.

    Forward hooks capture the two layers' outputs during a single forward pass.
    Each captured output then gets its own figure with one grayscale panel per
    channel.

    Args:
        model: Module exposing ``conv1`` and ``conv2`` sub-modules; it is put
            in eval mode.
        input_image: A single un-batched image tensor; the batch dimension is
            added here before the forward pass.
    """
    model.eval()
    layers = [model.conv1, model.conv2]
    activations = []

    def hook(module, inputs, output):
        # Detach so the stored tensor never keeps an autograd graph alive.
        activations.append(output.detach())

    # Send the image to the device the model's weights are on rather than to a
    # notebook-level ``device`` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    hooks = [layer.register_forward_hook(hook) for layer in layers]
    try:
        # Inference only: gradients are not needed for visualisation.
        with torch.no_grad():
            model(input_image.unsqueeze(0).to(target_device))
    finally:
        # Always unregister, even if the forward pass raises, so hooks do not
        # accumulate on the model across repeated calls.
        for handle in hooks:
            handle.remove()

    for activation in activations:
        activation = activation.squeeze(0).cpu().numpy()
        num_filters = activation.shape[0]

        # squeeze=False keeps ``axes`` 2-D even for a single filter, where
        # plt.subplots would otherwise return a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(1, num_filters, figsize=(15, 8), squeeze=False)
        for j, ax in enumerate(axes[0]):
            ax.imshow(activation[j], cmap='gray')
            ax.axis('off')
            ax.set_title(f"Filter {j + 1}")
        plt.show()


# Use the first test-set item (image and its label) as the probe input.
sample_image, sample_label = testset[0]
visualize_feature_maps(model=model, input_image=sample_image)