<a href="https://colab.research.google.com/github/parthkotwal/braai-cnn/blob/main/3_pytorch_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import google.colab.drive as drive

# Mount Google Drive; later cells read the dataset from, and write the
# checkpoint to, paths under /content/drive/MyDrive/braai_cnn/.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ZTFDataset(Dataset):
  """Map-style dataset over image/label arrays stored in a ``.npz`` archive.

  The archive is read eagerly. The arrays ``X_<split>`` (images) and
  ``y_<split>`` (labels) are copied into memory, and the file is closed
  before the constructor returns.

  Args:
    npz_path: Path to the ``.npz`` file.
    split: Suffix of the arrays to load. ``'train'`` (the default) loads
      ``X_train`` / ``y_train``, which matches the original behavior.

  Raises:
    KeyError: If the archive has no ``X_<split>`` or ``y_<split>`` array.
    ValueError: If the image and label arrays have different lengths.
  """

  def __init__(self, npz_path, split='train'):
    # np.load on an .npz returns a lazily-read NpzFile that holds an open
    # file handle. The context manager closes it once the arrays are copied out.
    with np.load(npz_path) as data:
      self.images = data[f'X_{split}']
      self.labels = data[f'y_{split}']
    # Fail here rather than with an IndexError halfway through an epoch.
    if len(self.images) != len(self.labels):
      raise ValueError(
          f"image/label count mismatch in {npz_path!r} (split {split!r}): "
          f"{len(self.images)} images vs {len(self.labels)} labels")

  def __len__(self):
    """Return the number of samples (one per label)."""
    return len(self.labels)

  def __getitem__(self, idx):
    """Return ``(image, label)``: the image cast to float32, the label to ``int``."""
    image = self.images[idx].astype(np.float32)
    label = int(self.labels[idx])
    return image, label

In [9]:
import torch.nn as nn
import torch.nn.functional as F

class PyTorchCNN(nn.Module):
  """Binary-classification CNN that outputs one probability per sample.

  The network has two conv blocks, then a 256-unit hidden layer and a single
  sigmoid output. Each block is two 3x3 'same' convolutions with ReLU,
  followed by max-pooling and dropout.

  Input: float tensor of shape (batch, 3, H, W). ``fc1`` expects
  32 * 8 * 8 = 2048 flattened features. The pooling below produces that only
  when H and W are each in 60..67 (e.g. 63x63 cutouts). Other sizes fail
  with a shape mismatch at ``fc1``.
  Output: tensor of shape (batch, 1) with values in (0, 1), which is what
  ``nn.BCELoss`` expects.

  Layer attribute names form the saved ``state_dict`` keys. Renaming them
  breaks loading of existing checkpoints.
  """

  def __init__(self):
    super().__init__()

    # Block 1: 3 -> 16 channels; pool1 halves each spatial dimension.
    self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    self.conv1_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
    self.pool1 = nn.MaxPool2d(stride=2, kernel_size=2)
    self.dropout1 = nn.Dropout(0.25)

    # Block 2: 16 -> 32 channels.
    self.conv2_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
    self.conv2_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
    # NOTE(review): a 2-wide window with stride 4 covers only half the rows
    # and half the columns, so 3/4 of the activations are never pooled. If a
    # 4x4 pool was intended, use kernel_size=4, stride=4. That would change
    # the fc1 input size (32*7*7 for 63x63 inputs) and invalidate saved
    # checkpoints, so it is left unchanged here.
    self.pool2 = nn.MaxPool2d(stride=4, kernel_size=2)
    self.dropout2 = nn.Dropout(0.25)

    # Fully connected head: 2048 flattened features -> 256 -> 1.
    self.fc1 = nn.Linear(in_features=2048, out_features=256)
    self.dropout3 = nn.Dropout(0.5)
    self.fc2 = nn.Linear(in_features=256, out_features=1)

  def forward(self, x):
    """Map a (batch, 3, H, W) tensor to (batch, 1) probabilities."""
    # Block 1
    x = F.relu(self.conv1_1(x))
    x = F.relu(self.conv1_2(x))
    x = self.pool1(x)
    x = self.dropout1(x)

    # Block 2
    x = F.relu(self.conv2_1(x))
    x = F.relu(self.conv2_2(x))
    x = self.pool2(x)
    x = self.dropout2(x)

    # Fully connected head; flatten everything except the batch dimension.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.dropout3(x)
    x = self.fc2(x)

    # torch.sigmoid is the canonical form (F.sigmoid is a legacy alias).
    # NOTE: returning raw logits and training with nn.BCEWithLogitsLoss would
    # be more numerically stable. This notebook's BCELoss and 0.5 threshold
    # expect probabilities, so the sigmoid stays.
    return torch.sigmoid(x)

In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [10]:
import torch.optim as optim

# Seed PyTorch's global RNG so weight initialisation is reproducible. When
# the cells run in order, this also fixes dropout masks and DataLoader
# shuffling. It only fixes the RNG streams; bit-exact GPU reproducibility
# additionally needs deterministic cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

model = PyTorchCNN().to(device)
# BCELoss pairs with the sigmoid probabilities the model returns.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
def train(model, dataloader, epochs=20, opt=None, criterion=None, dev=None):
  """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

  Prints one line per epoch with the sample-weighted mean loss and the
  accuracy of thresholding the model output at 0.5. Both metrics are
  accumulated during training (dropout active, weights changing mid-epoch).
  They are a running training estimate, not a validation score.

  Args:
    model: Module mapping an input batch to (batch, 1) outputs compared
      against 0.5, i.e. probabilities.
    dataloader: Iterable of ``(inputs, labels)`` batches with binary labels.
    epochs: Number of full passes over ``dataloader``.
    opt: Optimizer to step. Defaults to the notebook-level ``optimizer``.
    criterion: Loss applied to ``(outputs, targets)``, where targets are
      float tensors of shape (batch, 1). Defaults to the notebook-level
      ``loss_fn``.
    dev: Device batches are moved to. Defaults to the notebook-level ``device``.
  """
  # Fall back to notebook globals only when nothing is passed, so existing
  # calls like ``train(model, loader, epochs=...)`` keep working unchanged.
  opt = optimizer if opt is None else opt
  criterion = loss_fn if criterion is None else criterion
  dev = device if dev is None else dev

  model.train()

  for epoch in range(epochs):
    total_loss = 0.0
    correct, total = 0, 0

    for inputs, labels in dataloader:
      inputs = inputs.to(dev).float()
      # Integer class labels -> float targets shaped (batch, 1) like the output.
      labels = labels.to(dev).float().reshape(-1, 1)

      opt.zero_grad()
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      opt.step()

      batch_size = inputs.size(0)
      # loss is a batch mean; re-weight it so the epoch figure is a per-sample mean.
      total_loss += loss.item() * batch_size
      preds = (outputs > 0.5).float()
      correct += (preds == labels).sum().item()
      total += batch_size

    if total == 0:
      # Empty dataloader: nothing was seen, so there is no loss/accuracy to report.
      print(f"Epoch {epoch+1}/{epochs}: no batches, skipped")
      continue

    avg_loss = total_loss / total
    accuracy = correct / total

    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

In [7]:
# ZTFDataset reads the X_train / y_train arrays from this archive on Drive.
train_dataset = ZTFDataset(npz_path="/content/drive/MyDrive/braai_cnn/ztf_dataset_split.npz")
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=64,
)

In [12]:
# 20 passes over the training split; train() prints per-epoch loss/accuracy.
train(model, dataloader=train_loader, epochs=20)

Epoch 1/20, Loss: 0.6499, Accuracy: 0.6069
Epoch 2/20, Loss: 0.5644, Accuracy: 0.6901
Epoch 3/20, Loss: 0.5330, Accuracy: 0.7187
Epoch 4/20, Loss: 0.5034, Accuracy: 0.7418
Epoch 5/20, Loss: 0.4704, Accuracy: 0.7670
Epoch 6/20, Loss: 0.4501, Accuracy: 0.7834
Epoch 7/20, Loss: 0.4300, Accuracy: 0.7958
Epoch 8/20, Loss: 0.4200, Accuracy: 0.8020
Epoch 9/20, Loss: 0.4106, Accuracy: 0.8092
Epoch 10/20, Loss: 0.4081, Accuracy: 0.8111
Epoch 11/20, Loss: 0.3977, Accuracy: 0.8177
Epoch 12/20, Loss: 0.3885, Accuracy: 0.8228
Epoch 13/20, Loss: 0.3853, Accuracy: 0.8216
Epoch 14/20, Loss: 0.3800, Accuracy: 0.8241
Epoch 15/20, Loss: 0.3751, Accuracy: 0.8285
Epoch 16/20, Loss: 0.3722, Accuracy: 0.8293
Epoch 17/20, Loss: 0.3703, Accuracy: 0.8322
Epoch 18/20, Loss: 0.3673, Accuracy: 0.8322
Epoch 19/20, Loss: 0.3632, Accuracy: 0.8374
Epoch 20/20, Loss: 0.3589, Accuracy: 0.8360


In [14]:
# Persist the trained weights. The Adam optimizer state (moment estimates,
# step counts) is saved alongside so training can be resumed later. Code
# that only reads checkpoint['model'] is unaffected by the extra key.
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, '/content/drive/MyDrive/braai_cnn/pytorch_cnn.pt')