In [7]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
import torchvision.models as models
import torch.nn as nn
import torch


In [8]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import tqdm


class CatNotCatDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the animal subdirectories.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Separate handling for 'cat' and 'not cat' images
        for sub_dir in os.listdir(root_dir):
            class_path = os.path.join(root_dir, sub_dir)
            if os.path.isdir(class_path):
                # Label 1 for cat, 0 for not cat
                label = 1 if sub_dir.lower() == 'cat' else 0
                # If 'not cat', traverse further into subdirectories
                if label == 0:
                    for sub_class in os.listdir(class_path):
                        sub_class_path = os.path.join(class_path, sub_class)
                        if os.path.isdir(sub_class_path):
                            for img_file in os.listdir(sub_class_path):
                                if img_file.lower().endswith(('png', 'jpg', 'jpeg')):
                                    img_path = os.path.join(sub_class_path, img_file)
                                    self.samples.append((img_path, label))
                else:
                    for img_file in os.listdir(class_path):
                        if img_file.lower().endswith(('png', 'jpg', 'jpeg')):
                            img_path = os.path.join(class_path, img_file)
                            self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path)
        if self.transform:
            image = self.transform(image)
        return image, label


In [9]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

dataset = CatNotCatDataset(root_dir='one_vs_rest', transform=transform)

from torch.utils.data import DataLoader, random_split, Subset

def get_dataloaders(dataset, fold, num_folds=3):
    total_size = len(dataset)
    fraction = 1.0 / num_folds
    seg = int(total_size * fraction)
    # Calculate the start and end indices for the current fold
    start, end = fold * seg, (fold + 1) * seg if fold < num_folds - 1 else total_size
    
    # Create indices for training and validation
    train_indices = list(range(0, start)) + list(range(end, total_size))
    val_indices = list(range(start, end))
    
    # Creating PyTorch datasets for training and validation
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    
    # Creating data loaders for training and validation
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)
    
    return train_loader, val_loader

In [10]:
class ResNet18CatNotCat(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNet18CatNotCat, self).__init__()
        self.resnet18 = models.resnet18(pretrained=True)
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)
        
        self.feature_maps = []  # To store the outputs of convolutional layers
        self.hooks = []  # To store the hooks

        # Function to recursively register hooks on all convolutional layers
        def register_hooks(module):
            if isinstance(module, nn.Conv2d):
                self.hooks.append(module.register_forward_hook(self.hook_fn))
            elif isinstance(module, nn.Sequential) or isinstance(module, nn.ModuleList):
                for child in module.children():
                    register_hooks(child)

        # Register hooks to capture the outputs of all convolutional layers
        register_hooks(self.resnet18)

    def hook_fn(self, module, input, output):
        self.feature_maps.append(output)

    def forward(self, x):
        self.feature_maps = []  # Reset feature maps on each forward pass
        x = self.resnet18(x)
        return x, self.feature_maps  # Return both the final output and the feature maps

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()

In [11]:
import matplotlib.pyplot as plt

def visualize_feature_maps(feature_maps):
    # Assuming feature_maps is a list of tensors from the convolutional layers
    for layer, f_map in enumerate(feature_maps):
        layer += 1  # Start layer indexing at 1
        plt.figure(figsize=(20, 15))
        channels = f_map.shape[1]
        
        for i in range(channels):
            plt.subplot(channels // 8 + 1, 8, i + 1)  # Arrange plots in a grid
            plt.imshow(f_map[0, i].detach().cpu().numpy(), cmap='gray')
            plt.axis('off')
        
        plt.show()


In [12]:
def train_model(model, train_loader, val_loader, num_epochs=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in tqdm.tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, _ = model(images)  # Assuming your model returns outputs and feature maps
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs, _ = model(images)  # Adjusted to only expect the final output
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        print(f'Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}, Validation Loss: {val_loss / len(val_loader)}, Accuracy: {100 * correct / total}%')

    print('Finished Training')
    
num_folds = 3
num_epochs = 3

for fold in range(num_folds):
    print(f"Starting fold {fold+1}/{num_folds}")
    train_loader, val_loader = get_dataloaders(dataset, fold, num_folds)
    
    # Re-initialize the model for each fold
    model = ResNet18CatNotCat(num_classes=2)
    
    # Train the model on the current fold
    train_model(model, train_loader, val_loader, num_epochs)

Starting fold 1/3


100%|██████████| 2/2 [00:03<00:00,  1.63s/it]
 33%|███▎      | 1/3 [00:03<00:07,  3.96s/it]

Epoch 1, Loss: 1.508114993572235, Validation Loss: 1.5895168781280518, Accuracy: 10.0%


100%|██████████| 2/2 [00:03<00:00,  1.67s/it]
 67%|██████▋   | 2/3 [00:07<00:04,  4.00s/it]

Epoch 2, Loss: 0.4756036549806595, Validation Loss: 0.31816595792770386, Accuracy: 80.0%


100%|██████████| 2/2 [00:03<00:00,  1.58s/it]
100%|██████████| 3/3 [00:11<00:00,  3.91s/it]


Epoch 3, Loss: 0.1185636855661869, Validation Loss: 0.018176505342125893, Accuracy: 100.0%
Finished Training
Starting fold 2/3


100%|██████████| 2/2 [00:03<00:00,  1.57s/it]
 33%|███▎      | 1/3 [00:03<00:07,  3.73s/it]

Epoch 1, Loss: 1.2314414978027344, Validation Loss: 0.9124992489814758, Accuracy: 30.0%


100%|██████████| 2/2 [00:03<00:00,  1.65s/it]
 67%|██████▋   | 2/3 [00:07<00:03,  3.83s/it]

Epoch 2, Loss: 0.374253049492836, Validation Loss: 3.3409228324890137, Accuracy: 35.0%


100%|██████████| 2/2 [00:03<00:00,  1.58s/it]
100%|██████████| 3/3 [00:11<00:00,  3.79s/it]


Epoch 3, Loss: 0.17251741886138916, Validation Loss: 0.02251262031495571, Accuracy: 100.0%
Finished Training
Starting fold 3/3


100%|██████████| 2/2 [00:03<00:00,  1.62s/it]
 33%|███▎      | 1/3 [00:03<00:07,  3.78s/it]

Epoch 1, Loss: 0.5021655857563019, Validation Loss: 0.0849669873714447, Accuracy: 95.0%


100%|██████████| 2/2 [00:03<00:00,  1.57s/it]
 67%|██████▋   | 2/3 [00:07<00:03,  3.73s/it]

Epoch 2, Loss: 0.3717193379998207, Validation Loss: 4.082569122314453, Accuracy: 0.0%


100%|██████████| 2/2 [00:03<00:00,  1.62s/it]
100%|██████████| 3/3 [00:11<00:00,  3.76s/it]

Epoch 3, Loss: 0.19213799387216568, Validation Loss: 2.105037212371826, Accuracy: 10.0%
Finished Training





In [13]:
def visualize_feature_maps(feature_maps):
    for layer, f_map in enumerate(feature_maps):
        layer += 1  # Start layer indexing at 1
        plt.figure(figsize=(20, 15))
        channels = f_map.shape[1]
        
        for i in range(min(channels, 64)):  # Visualize up to 64 channels to keep the output manageable
            plt.subplot(min(channels // 8 + 1, 8), 8, i + 1)  # Arrange plots in a grid
            plt.imshow(f_map[0, i].detach().cpu().numpy(), cmap='gray')
            plt.axis('off')
        
        plt.show()