<a href="https://colab.research.google.com/github/sanjanb/BiasNet-Pretrained-Model/blob/main/Unsupervised_Pretraining_and_Binary_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch torchvision matplotlib scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
pip install torchsummary



# **Building the Autoencoder**

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Preprocessing pipeline shared by both splits: tensor conversion followed by
# normalisation with the (0.5,) / (0.5,) statistics.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Fetch MNIST into ./data (downloaded on first use).
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
mnist_test = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Batches of 64; only the training split is shuffled.
train_loader = DataLoader(dataset=mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 20.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.69MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.09MB/s]


## **Define the Autoencoder Architecture**

In [None]:
import torch.nn as nn

class Autoencoder(nn.Module):
    """Fully connected autoencoder for 1x28x28 images with a 64-dimensional code.

    ``encoder`` flattens the image and maps 784 -> 128 -> 64 with ReLU
    activations; ``decoder`` maps 64 -> 128 -> 784, applies Tanh and reshapes
    the result back to (1, 28, 28).
    """

    def __init__(self):
        super().__init__()

        # Encoder: image -> 64-dim feature vector.
        encoder_layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: 64-dim feature vector -> reconstructed image.
        decoder_layers = [
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Tanh(),
            nn.Unflatten(1, (1, 28, 28)),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction with shape (N, 1, 28, 28)."""
        return self.decoder(self.encoder(x))


In [None]:
from torchsummary import summary

# One instance is enough for both reports below.
model = Autoencoder()

# torchsummary's `summary` defaults to device="cuda" and then builds its probe
# input on the GPU whenever one is available. This model is still on the CPU,
# so request a CPU probe explicitly to avoid a device-mismatch error on GPU
# runtimes.
summary(model, input_size=(1, 28, 28), device="cpu")

# Module tree as rendered by PyTorch itself.
print(model)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 784]               0
            Linear-2                  [-1, 128]         100,480
              ReLU-3                  [-1, 128]               0
            Linear-4                   [-1, 64]           8,256
              ReLU-5                   [-1, 64]               0
            Linear-6                  [-1, 128]           8,320
              ReLU-7                  [-1, 128]               0
            Linear-8                  [-1, 784]         101,136
              Tanh-9                  [-1, 784]               0
        Unflatten-10            [-1, 1, 28, 28]               0
Total params: 218,192
Trainable params: 218,192
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.03
Params size (MB): 0.83
Estimated T

## **Pretraining the Autoencoder**

In [None]:
import torch.optim as optim
from tqdm import tqdm

# Reconstruction pretraining of the autoencoder with early stopping.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

# Hold out 10% of the training images for validation. Early stopping on the
# *training* loss is ineffective: that loss keeps creeping down, so the patience
# counter never advances and the loop always runs to the last epoch.
# The split is seeded so the same images are held out on every run.
val_size = len(mnist_train) // 10
pretrain_subset, val_subset = random_split(
    mnist_train,
    [len(mnist_train) - val_size, val_size],
    generator=torch.Generator().manual_seed(0),
)
pretrain_loader = DataLoader(pretrain_subset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)

# Early stopping logic (driven by the validation loss)
best_loss = float('inf')
patience = 5
counter = 0

for epoch in range(50):
    # ---- training pass ----
    autoencoder.train()
    running_loss = 0.0

    for inputs, _ in tqdm(pretrain_loader):
        inputs = inputs.to(device)
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)  # the target is the input itself

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(pretrain_loader)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')

    # ---- validation pass (no gradients, no weight updates) ----
    autoencoder.eval()
    val_total = 0.0
    with torch.no_grad():
        for inputs, _ in val_loader:
            inputs = inputs.to(device)
            # Weight each batch by its size so a short final batch is not over-counted.
            val_total += criterion(autoencoder(inputs), inputs).item() * inputs.size(0)
    val_loss = val_total / len(val_subset)
    print(f'Validation loss: {val_loss:.4f}')

    # Checkpoint only when the held-out loss improves; stop after `patience`
    # consecutive epochs without improvement.
    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        torch.save(autoencoder.state_dict(), 'pretrained_autoencoder.pth')
        print("Model saved.")
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping.")
            break


100%|██████████| 938/938 [00:20<00:00, 45.97it/s]


Epoch 1, Loss: 0.1091
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.21it/s]


Epoch 2, Loss: 0.0526
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.43it/s]


Epoch 3, Loss: 0.0430
Model saved.


100%|██████████| 938/938 [00:21<00:00, 42.88it/s]


Epoch 4, Loss: 0.0382
Model saved.


100%|██████████| 938/938 [00:23<00:00, 40.72it/s]


Epoch 5, Loss: 0.0356
Model saved.


100%|██████████| 938/938 [00:21<00:00, 42.68it/s]


Epoch 6, Loss: 0.0333
Model saved.


100%|██████████| 938/938 [00:22<00:00, 42.29it/s]


Epoch 7, Loss: 0.0318
Model saved.


100%|██████████| 938/938 [00:21<00:00, 43.84it/s]


Epoch 8, Loss: 0.0306
Model saved.


100%|██████████| 938/938 [00:20<00:00, 44.92it/s]


Epoch 9, Loss: 0.0296
Model saved.


100%|██████████| 938/938 [00:21<00:00, 42.69it/s]


Epoch 10, Loss: 0.0289
Model saved.


100%|██████████| 938/938 [00:26<00:00, 34.79it/s]


Epoch 11, Loss: 0.0283
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.46it/s]


Epoch 12, Loss: 0.0278
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.12it/s]


Epoch 13, Loss: 0.0272
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.04it/s]


Epoch 14, Loss: 0.0268
Model saved.


100%|██████████| 938/938 [00:22<00:00, 42.42it/s]


Epoch 15, Loss: 0.0264
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.92it/s]


Epoch 16, Loss: 0.0261
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.32it/s]


Epoch 17, Loss: 0.0258
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.38it/s]


Epoch 18, Loss: 0.0255
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.99it/s]


Epoch 19, Loss: 0.0253
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.19it/s]


Epoch 20, Loss: 0.0251
Model saved.


100%|██████████| 938/938 [00:20<00:00, 44.75it/s]


Epoch 21, Loss: 0.0250
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.41it/s]


Epoch 22, Loss: 0.0248
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.79it/s]


Epoch 23, Loss: 0.0247
Model saved.


100%|██████████| 938/938 [00:20<00:00, 44.78it/s]


Epoch 24, Loss: 0.0246
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.00it/s]


Epoch 25, Loss: 0.0244
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.44it/s]


Epoch 26, Loss: 0.0243
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.00it/s]


Epoch 27, Loss: 0.0242
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.51it/s]


Epoch 28, Loss: 0.0240
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.27it/s]


Epoch 29, Loss: 0.0239
Model saved.


100%|██████████| 938/938 [00:20<00:00, 44.91it/s]


Epoch 30, Loss: 0.0239
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.35it/s]


Epoch 31, Loss: 0.0238
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.49it/s]


Epoch 32, Loss: 0.0236
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.02it/s]


Epoch 33, Loss: 0.0236
Model saved.


100%|██████████| 938/938 [00:23<00:00, 40.15it/s]


Epoch 34, Loss: 0.0235
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.99it/s]


Epoch 35, Loss: 0.0234
Model saved.


100%|██████████| 938/938 [00:21<00:00, 44.26it/s]


Epoch 36, Loss: 0.0233
Model saved.


100%|██████████| 938/938 [00:21<00:00, 44.22it/s]


Epoch 37, Loss: 0.0232
Model saved.


100%|██████████| 938/938 [00:20<00:00, 46.09it/s]


Epoch 38, Loss: 0.0231
Model saved.


100%|██████████| 938/938 [00:21<00:00, 44.29it/s]


Epoch 39, Loss: 0.0230
Model saved.


100%|██████████| 938/938 [00:21<00:00, 44.11it/s]


Epoch 40, Loss: 0.0229
Model saved.


100%|██████████| 938/938 [00:20<00:00, 45.39it/s]


Epoch 41, Loss: 0.0229
Model saved.


100%|██████████| 938/938 [00:21<00:00, 44.10it/s]


Epoch 42, Loss: 0.0227
Model saved.


100%|██████████| 938/938 [00:21<00:00, 43.53it/s]


Epoch 43, Loss: 0.0227
Model saved.


100%|██████████| 938/938 [00:21<00:00, 43.33it/s]


Epoch 44, Loss: 0.0226
Model saved.


100%|██████████| 938/938 [00:21<00:00, 44.24it/s]


Epoch 45, Loss: 0.0225
Model saved.


100%|██████████| 938/938 [00:22<00:00, 42.50it/s]


Epoch 46, Loss: 0.0224
Model saved.


100%|██████████| 938/938 [00:22<00:00, 42.33it/s]


Epoch 47, Loss: 0.0224
Model saved.


100%|██████████| 938/938 [00:22<00:00, 41.74it/s]


Epoch 48, Loss: 0.0224
Model saved.


100%|██████████| 938/938 [00:22<00:00, 42.00it/s]


Epoch 49, Loss: 0.0223
Model saved.


100%|██████████| 938/938 [00:22<00:00, 41.86it/s]

Epoch 50, Loss: 0.0222
Model saved.





# **Fine-Tuning with Binary Classification**

## **Dataset with Even/Odd Labels**

In [None]:
# Binary labels: 0 for Even, 1 for Odd
def binary_label(label):
    """Return the parity class of a digit label: 0 when even, 1 when odd."""
    return int(label % 2 != 0)

class BinaryMNIST(torch.utils.data.Dataset):
    """View of an MNIST-style dataset whose targets are digit parity.

    Targets are 0 for even digits and 1 for odd digits. Images are read from
    the wrapped dataset's raw ``data`` tensor and passed through the wrapped
    dataset's transforms, minus ``ToTensor`` (the data is already a tensor).

    Args:
        original_dataset: dataset exposing ``data``, ``targets`` and a
            ``transform`` that is a ``transforms.Compose``.
    """

    def __init__(self, original_dataset):
        self.data = original_dataset.data
        # Compute parity out of place. Tensor.apply_ works in place and would
        # overwrite the wrapped dataset's own digit labels.
        self.targets = (original_dataset.targets % 2).long()
        original_transforms = original_dataset.transform.transforms
        # ToTensor expects a PIL image / ndarray, so it is dropped here; its
        # rescaling is re-applied manually in __getitem__.
        self.transform = transforms.Compose([t for t in original_transforms if not isinstance(t, transforms.ToTensor)])

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image gets a leading channel dimension."""
        img = self.data[idx]
        # ToTensor would have rescaled the pixels to [0, 1] before the
        # remaining transforms ran. Do the same here so the images match what
        # the original pipeline produced (assumes 0-255 raw pixels, as stored
        # by torchvision's MNIST).
        img = img.float().div(255.0).unsqueeze(0)
        img = self.transform(img)
        label = self.targets[idx]
        return img, label

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.targets)

# Even/odd views of the two MNIST splits.
binary_train, binary_test = (
    BinaryMNIST(split) for split in (mnist_train, mnist_test)
)

# Batches of 64; only the training split is shuffled.
train_loader_bin = DataLoader(dataset=binary_train, batch_size=64, shuffle=True)
test_loader_bin = DataLoader(dataset=binary_test, batch_size=64, shuffle=False)

## **Freeze Encoder and Add Classifier Head**

In [None]:
class BinaryClassifier(nn.Module):
    """Frozen encoder followed by a small trainable head with a sigmoid output.

    Args:
        pretrained_encoder: module producing 64 features per sample; all of
            its parameters are frozen (``requires_grad = False``).
    """

    def __init__(self, pretrained_encoder):
        super().__init__()
        self.encoder = pretrained_encoder
        # Freeze the encoder so only the head below receives gradients.
        for weight in self.encoder.parameters():
            weight.requires_grad = False

        # Head: 64 -> 32 -> 1, squashed to (0, 1) by the sigmoid.
        head_layers = [
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return one sigmoid output per sample, shape (N, 1)."""
        return self.classifier(self.encoder(x))


## **Training Binary Classifier**

In [None]:
# Rebuild the autoencoder and restore the pretrained weights.
autoencoder = Autoencoder().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime still loads on a CPU-only one.
autoencoder.load_state_dict(
    torch.load('pretrained_autoencoder.pth', map_location=device)
)

# Only the classifier head is handed to the optimizer; the encoder is frozen
# inside BinaryClassifier.
model = BinaryClassifier(autoencoder.encoder).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    epoch_loss = 0.0

    for inputs, labels in train_loader_bin:
        # BCELoss needs float targets with the same (N, 1) shape as the outputs.
        inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean per-batch loss for the epoch.
    print(f"Epoch {epoch+1}, Loss: {epoch_loss / len(train_loader_bin):.4f}")

torch.save(model.state_dict(), "binary_classifier.pth")


Epoch 1, Loss: 5.7800
Epoch 2, Loss: 0.9497
Epoch 3, Loss: 0.2756
Epoch 4, Loss: 0.1725
Epoch 5, Loss: 0.1540
Epoch 6, Loss: 0.1377
Epoch 7, Loss: 0.1237
Epoch 8, Loss: 0.1196
Epoch 9, Loss: 0.1144
Epoch 10, Loss: 0.1085


# **Evaluation + Report**

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

# Score the classifier on the binary test split.
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for batch_images, batch_labels in test_loader_bin:
        probabilities = model(batch_images.to(device)).cpu().numpy()
        # Outputs above 0.5 are predicted as class 1.
        batch_preds = (probabilities > 0.5).astype(int).flatten()
        all_preds.extend(batch_preds)
        all_labels.extend(batch_labels.numpy())

# Scalar metrics first, then the confusion matrix.
for metric_title, metric_fn in (
    ("Accuracy:", accuracy_score),
    ("Precision:", precision_score),
    ("Recall:", recall_score),
):
    print(metric_title, metric_fn(all_labels, all_preds))
print("Confusion Matrix:\n", confusion_matrix(all_labels, all_preds))


Accuracy: 0.9638
Precision: 0.9725230645808263
Recall: 0.9556562869530942
Confusion Matrix:
 [[4789  137]
 [ 225 4849]]
