In [39]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
import torchvision.models as models
import torch.nn as nn
import torch
from sklearn.model_selection import train_test_split


In [40]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import tqdm


class CatNotCatDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.cat_samples = []
        self.not_cat_samples = []

        for sub_dir in os.listdir(root_dir):
            class_path = os.path.join(root_dir, sub_dir)
            if os.path.isdir(class_path):
                label = 1 if sub_dir.lower() == 'cat' else 0
                target_list = self.cat_samples if label == 1 else self.not_cat_samples
                
                for img_file in os.listdir(class_path):
                    if img_file.lower().endswith(('png', 'jpg', 'jpeg')):
                        img_path = os.path.join(class_path, img_file)
                        target_list.append((img_path, label))
        
        # Combine cat and not cat samples
        self.samples = self.cat_samples + self.not_cat_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path)
        if self.transform:
            image = self.transform(image)
        return image, label

In [41]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

dataset = CatNotCatDataset(root_dir='one_vs_rest', transform=transform)

from torch.utils.data import DataLoader, random_split, Subset
def check_split_balance(indices, dataset):
    labels = [dataset.samples[i][1] for i in indices]
    print(f"Class distribution: {sum(labels)} cats, {len(labels) - sum(labels)} not cats out of {len(labels)} samples")

def get_dataloaders(dataset, test_size=0.2, random_state=42):
    # Split cat and not cat samples into training and validation sets separately
    cat_train, cat_val = train_test_split(dataset.cat_samples, test_size=0.2, random_state=random_state)
    not_cat_train, not_cat_val = train_test_split(dataset.not_cat_samples, test_size=0.2, random_state=random_state)
    
    # Combine the splits to get the final training and validation sets
    train_samples = cat_train + not_cat_train
    val_samples = cat_val + not_cat_val
    
    # Convert lists of samples into Subset objects
    train_indices = [dataset.samples.index(sample) for sample in train_samples]
    val_indices = [dataset.samples.index(sample) for sample in val_samples]
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    
    # Creating data loaders for training and validation
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)
    
    return train_loader, val_loader

In [42]:
class ResNet18CatNotCat(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNet18CatNotCat, self).__init__()
        self.resnet18 = models.resnet18(pretrained=True)
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)
        
        self.feature_maps = []  # To store the outputs of convolutional layers
        self.hooks = []  # To store the hooks

        # Function to recursively register hooks on all convolutional layers
        def register_hooks(module):
            if isinstance(module, nn.Conv2d):
                self.hooks.append(module.register_forward_hook(self.hook_fn))
            elif isinstance(module, nn.Sequential) or isinstance(module, nn.ModuleList):
                for child in module.children():
                    register_hooks(child)

        # Register hooks to capture the outputs of all convolutional layers
        register_hooks(self.resnet18)

    def hook_fn(self, module, input, output):
        self.feature_maps.append(output)

    def forward(self, x):
        self.feature_maps = []  # Reset feature maps on each forward pass
        x = self.resnet18(x)
        return x, self.feature_maps  # Return both the final output and the feature maps

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()

In [43]:
import matplotlib.pyplot as plt

def visualize_feature_maps(feature_maps):
    # Assuming feature_maps is a list of tensors from the convolutional layers
    for layer, f_map in enumerate(feature_maps):
        layer += 1  # Start layer indexing at 1
        plt.figure(figsize=(20, 15))
        channels = f_map.shape[1]
        
        for i in range(channels):
            plt.subplot(channels // 8 + 1, 8, i + 1)  # Arrange plots in a grid
            plt.imshow(f_map[0, i].detach().cpu().numpy(), cmap='gray')
            plt.axis('off')
        
        plt.show()


In [48]:
def train_model(model, train_loader, val_loader, num_epochs=1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in tqdm.tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, _ = model(images)  # Assuming your model returns outputs and feature maps
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs, _ = model(images)  # Adjusted to only expect the final output
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        print(f'Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}, Validation Loss: {val_loss / len(val_loader)}, Accuracy: {100 * correct / total}%')

    print('Finished Training')
    visualize_feature_maps(model.feature_maps)
    
num_folds = 3
num_epochs = 3

for fold in range(num_folds):
    print(f"Starting fold {fold+1}/{num_folds}")
    train_loader, val_loader = get_dataloaders(dataset, fold, num_folds)
    
    # Re-initialize the model for each fold
    model = ResNet18CatNotCat(num_classes=2)
    
    # Train the model on the current fold
    train_model(model, train_loader, val_loader, num_epochs)

Starting fold 1/3


100%|██████████| 15/15 [00:39<00:00,  2.64s/it]
 33%|███▎      | 1/3 [00:43<01:26, 43.48s/it]

Epoch 1, Loss: 0.3202224279443423, Validation Loss: 2.6313564777374268, Accuracy: 42.016806722689076%


100%|██████████| 15/15 [00:41<00:00,  2.78s/it]
 67%|██████▋   | 2/3 [01:29<00:44, 44.81s/it]

Epoch 2, Loss: 0.14376907444869477, Validation Loss: 0.2945615539792925, Accuracy: 91.59663865546219%


100%|██████████| 15/15 [00:42<00:00,  2.81s/it]
100%|██████████| 3/3 [02:15<00:00, 45.08s/it]


Epoch 3, Loss: 0.05751069579273462, Validation Loss: 0.1908994406403508, Accuracy: 98.31932773109244%
Finished Training
Starting fold 2/3


100%|██████████| 15/15 [00:41<00:00,  2.77s/it]
 33%|███▎      | 1/3 [00:45<01:31, 45.76s/it]

Epoch 1, Loss: 0.34428858359654746, Validation Loss: 1.5808865539729595, Accuracy: 88.23529411764706%


100%|██████████| 15/15 [00:42<00:00,  2.82s/it]
 67%|██████▋   | 2/3 [01:32<00:46, 46.12s/it]

Epoch 2, Loss: 0.1664735125998656, Validation Loss: 0.18247494287788868, Accuracy: 93.27731092436974%


100%|██████████| 15/15 [00:43<00:00,  2.91s/it]
100%|██████████| 3/3 [02:19<00:00, 46.63s/it]


Epoch 3, Loss: 0.04502687333151698, Validation Loss: 1.3819148242473602, Accuracy: 63.865546218487395%
Finished Training
Starting fold 3/3


100%|██████████| 15/15 [00:42<00:00,  2.82s/it]
 33%|███▎      | 1/3 [00:46<01:32, 46.35s/it]

Epoch 1, Loss: 0.44018866469462714, Validation Loss: 0.767303096174146, Accuracy: 89.91596638655462%


100%|██████████| 15/15 [00:44<00:00,  2.94s/it]
 67%|██████▋   | 2/3 [01:34<00:47, 47.32s/it]

Epoch 2, Loss: 0.2218310061842203, Validation Loss: 1.8402797589078546, Accuracy: 92.43697478991596%


100%|██████████| 15/15 [00:40<00:00,  2.69s/it]
100%|██████████| 3/3 [02:18<00:00, 46.28s/it]

Epoch 3, Loss: 0.13474627994000912, Validation Loss: 0.39895914820954204, Accuracy: 94.11764705882354%
Finished Training



