<a href="https://colab.research.google.com/github/vmazashvili/Neural-Networks/blob/main/MaskTune_Vano_Mazashvili_1993251.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MaskTune Model

- Train an ERM (Empirical Risk Minimization) model
- Generate masks based on the model outputs
- Fine-tune the model with the masks
- Test and evaluate
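
The mask-generation step can be sketched roughly as follows. This is a minimal illustration, not the notebook's implementation: it assumes a saliency heatmap per image is already available at input resolution (for example from Grad-CAM on the layer returned by `get_grad_cam_target_layer`), and it assumes a per-image threshold of mean plus two standard deviations. The function name `mask_salient_pixels` and the parameter `num_std` are hypothetical.

```python
import torch

def mask_salient_pixels(images, heatmaps, num_std=2.0):
    """Zero out the most salient pixels of each image.

    images:   (N, C, H, W) float tensor
    heatmaps: (N, H, W) float tensor of saliency scores, same H and W
    num_std:  assumed threshold parameter (mean + num_std * std per heatmap)
    """
    flat = heatmaps.flatten(start_dim=1)                       # (N, H*W)
    threshold = flat.mean(dim=1) + num_std * flat.std(dim=1)   # (N,)
    # keep is 1 where saliency is at or below the threshold, 0 where above
    keep = (heatmaps <= threshold[:, None, None]).to(images.dtype)
    return images * keep[:, None, :, :]                        # broadcast over channels

# Toy example: one 1x4x4 image of ones with a single highly salient pixel
images = torch.ones(1, 1, 4, 4)
heatmaps = torch.zeros(1, 4, 4)
heatmaps[0, 1, 2] = 1.0
masked = mask_salient_pixels(images, heatmaps)
```

In the toy example only the pixel at row 1, column 2 is zeroed. Fine-tuning on such masked inputs is what is meant to push the model toward features other than the ones it relied on first.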

In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn
from torch import optim


Small Convolutional Neural Network definition

In [3]:
class SmallCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, (3, 3), (1, 1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, (3, 3), (1, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 2)),
            nn.Conv2d(16, 32, (3, 3), (1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(32, 32, (3, 3), (1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 2)),
            nn.Flatten()
        )

        self.linear = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Linear(256, self.num_classes)
        )

    def get_grad_cam_target_layer(self):
        return self.backbone[-3]

    def forward(self, x):
        features = self.backbone(x)
        logits = self.linear(features)

        return logits
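
The input size of `nn.Linear(512, 256)` follows from the backbone's layer arithmetic on a 28x28 MNIST image. The sketch below reproduces that calculation, assuming unpadded 3x3 convolutions with stride 1 and 2x2 max pooling with stride 2, as in the definition above. The helper names `conv_out` and `pool_out` are illustrative only.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output side length of a convolution on a square input."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output side length of a max-pooling layer on a square input."""
    return (size - kernel) // stride + 1

side = 28                      # MNIST images are 28x28
side = conv_out(side)          # 26
side = conv_out(side)          # 24
side = pool_out(side)          # 12
side = conv_out(side)          # 10
side = conv_out(side)          # 8
side = pool_out(side)          # 4
flat_features = 32 * side * side   # 32 channels after the last conv block
print(flat_features)           # 512
```

A different input resolution or channel count would change this number, so the first linear layer would need to be resized accordingly.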


Instantiate the CNN defined above, download and prepare MNIST, then train and evaluate the model
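
The data preparation in this notebook uses `transforms.Normalize((0.1307,), (0.3081,))`, which subtracts the mean and divides by the standard deviation per channel. A minimal pure-Python sketch of that arithmetic for single pixel values, where `normalize_pixel` is a hypothetical helper rather than a torchvision function:

```python
MNIST_MEAN = 0.1307  # mean passed to transforms.Normalize in this notebook
MNIST_STD = 0.3081   # standard deviation passed to transforms.Normalize

def normalize_pixel(x, mean=MNIST_MEAN, std=MNIST_STD):
    """Apply (x - mean) / std to a pixel value already scaled to [0, 1]."""
    return (x - mean) / std

black = normalize_pixel(0.0)   # background pixel, about -0.42
white = normalize_pixel(1.0)   # full-intensity stroke, about 2.82
```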

In [4]:

# Define hyperparameters
num_epochs = 10  # Number of training epochs
batch_size = 64  # Batch size for training and testing
learning_rate = 0.001  # Learning rate for the optimizer

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert data to tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize data based on MNIST statistics
])

# Load MNIST datasets
train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Define the model
model = SmallCNN(num_classes=10)  # MNIST has 10 digit classes

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
model.train()  # Enable training behavior for the BatchNorm layers
for epoch in range(num_epochs):
    for data, target in train_loader:
        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the loss of the last mini-batch in this epoch
    print(f"Epoch: {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

# Test the model
model.eval()  # Use the BatchNorm running statistics instead of batch statistics
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        predicted = output.argmax(dim=1)  # Index of the highest logit per sample
        total += target.size(0)
        correct += (predicted == target).sum().item()

# Print test accuracy
print(f"Test Accuracy: {correct / total:.4f}")




Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Epoch: 1/10, Loss: 0.0377
Epoch: 2/10, Loss: 0.0048
Epoch: 3/10, Loss: 0.0008
Epoch: 4/10, Loss: 0.0012
Epoch: 5/10, Loss: 0.0393
Epoch: 6/10, Loss: 0.0022
Epoch: 7/10, Loss: 0.0525
Epoch: 8/10, Loss: 0.0015
Epoch: 9/10, Loss: 0.0249
Epoch: 10/10, Loss: 0.0093
Test Accuracy: 0.9921
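
Because `SmallCNN` contains batch-normalization layers, evaluation results depend on whether the model is in training or evaluation mode. The sketch below wraps the accuracy computation in a reusable function that switches to `model.eval()` first and restores the previous mode afterwards. The function name `evaluate_accuracy`, the toy linear model, and the random data used to exercise it are assumptions for illustration, not part of the notebook.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader):
    """Return the fraction of correctly classified samples in `loader`."""
    was_training = model.training
    model.eval()                      # use BatchNorm running statistics
    correct, total = 0, 0
    with torch.no_grad():             # no gradients needed for evaluation
        for data, target in loader:
            predicted = model(data).argmax(dim=1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    model.train(was_training)         # restore the previous mode
    return correct / total

# Toy check with random data standing in for MNIST
torch.manual_seed(0)
toy_data = TensorDataset(torch.randn(20, 8), torch.randint(0, 3, (20,)))
toy_loader = DataLoader(toy_data, batch_size=5)
toy_model = nn.Sequential(nn.Linear(8, 3))
accuracy = evaluate_accuracy(toy_model, toy_loader)
```

With the notebook's objects, the same function could be called as `evaluate_accuracy(model, test_loader)`, both after ERM training and after any later fine-tuning stage.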
