In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset path and the checkpoint
# path used by later cells both live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ZTFDataset(Dataset):
  """Train or validation slice of the arrays stored in an ``.npz`` archive.

  The archive must contain the keys ``'X_train'`` (images) and ``'y_train'``
  (labels). The split is purely positional: the first ``split`` fraction of
  the samples is the training part and the rest is the validation part. No
  shuffling is done here.

  Args:
    npz_path: Path (or file-like object) accepted by ``np.load``.
    train: If truthy, keep the leading ``split`` fraction; otherwise keep the
      remaining samples.
    split: Fraction in ``[0, 1]`` of samples assigned to the training part.

  Raises:
    ValueError: If ``split`` is outside ``[0, 1]`` or the image and label
      arrays have different lengths.
  """

  def __init__(self, npz_path, train, split):
    if not 0.0 <= split <= 1.0:
      raise ValueError(f"split must be a fraction in [0, 1], got {split!r}")

    # np.load returns a lazily-read NpzFile that keeps the archive open; the
    # context manager closes it once both arrays are materialised in memory.
    with np.load(npz_path) as data:
      images = data['X_train']
      labels = data['y_train']

    if len(images) != len(labels):
      raise ValueError(
          f"X_train has {len(images)} samples but y_train has {len(labels)}")

    # Keep the fraction and the derived index under separate names.
    boundary = int(len(labels) * split)
    if train:
      self.images = images[:boundary]
      self.labels = labels[:boundary]
    else:
      self.images = images[boundary:]
      self.labels = labels[boundary:]

  def __len__(self):
    """Number of samples in this slice."""
    return len(self.labels)

  def __getitem__(self, idx):
    """Returns ``(image, label)`` with the image as float32 and the label as a Python int."""
    image = self.images[idx].astype(np.float32)
    label = int(self.labels[idx])
    return image, label

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class PyTorchCNN(nn.Module):
  """Two-block convolutional binary classifier that outputs a probability.

  Each block is conv -> ReLU -> conv -> ReLU -> max-pool -> dropout. The
  flattened feature map then goes through a 256-unit hidden layer and a
  single sigmoid output unit, so the result has shape (batch, 1) with values
  in [0, 1].
  """

  def __init__(self):
    super().__init__()

    # Block 1: 3 -> 16 -> 16 channels, then a 2x2 max-pool with stride 2.
    self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    self.conv1_2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
    self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.dropout1 = nn.Dropout(p=0.25)

    # Block 2: 16 -> 32 -> 32 channels, then a 2x2 max-pool with stride 4.
    # NOTE(review): the stride is larger than the window, so some positions
    # are never pooled -- confirm this is intended rather than a 4x4 window.
    self.conv2_1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    self.conv2_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
    self.pool2 = nn.MaxPool2d(kernel_size=2, stride=4)
    self.dropout2 = nn.Dropout(p=0.25)

    # Fully connected head. 2048 = 32 channels x 8 x 8 positions after pool2;
    # presumably the inputs are 63x63 images -- TODO confirm against the data.
    self.fc1 = nn.Linear(2048, 256)
    self.dropout3 = nn.Dropout(p=0.5)
    self.fc2 = nn.Linear(256, 1)

  def forward(self, x):
    """Maps a (batch, 3, H, W) tensor to per-sample probabilities of shape (batch, 1)."""
    conv_blocks = (
        (self.conv1_1, self.conv1_2, self.pool1, self.dropout1),
        (self.conv2_1, self.conv2_2, self.pool2, self.dropout2),
    )
    for first_conv, second_conv, pool, drop in conv_blocks:
      x = F.relu(first_conv(x))
      x = F.relu(second_conv(x))
      x = drop(pool(x))

    # Fully connected head: flatten everything except the batch dimension.
    flat = x.view(x.shape[0], -1)
    hidden = self.dropout3(F.relu(self.fc1(flat)))
    return F.sigmoid(self.fc2(hidden))

In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

cuda


In [5]:
import torch.optim as optim

# Build the network and move it to the selected device before creating the
# optimizer, so the optimizer tracks the parameters that are actually trained.
model = PyTorchCNN()
model = model.to(device)

# Binary cross-entropy on probabilities: the network already ends in a sigmoid.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
def _run_epoch(model, loader, training):
  """Runs one pass over `loader` and returns (mean loss per sample, accuracy).

  When `training` is true the model is put in train mode and the module-level
  `optimizer` is stepped after every batch; otherwise the model is put in eval
  mode and no parameters are updated.
  """
  if training:
    model.train()
  else:
    model.eval()

  loss_sum = 0.0
  n_correct, n_seen = 0, 0

  for inputs, labels in loader:
    inputs = inputs.to(device).float()
    # BCELoss needs float targets shaped like the (batch, 1) model output.
    labels = labels.to(device).float().view(-1, 1)

    if training:
      optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    if training:
      loss.backward()
      optimizer.step()

    # loss is a batch mean; weight it by batch size so the epoch figure is a
    # true per-sample mean even when the last batch is smaller.
    loss_sum += loss.item() * inputs.size(0)
    preds = (outputs > 0.5).float()
    n_correct += (preds == labels).sum().item()
    n_seen += labels.size(0)

  return loss_sum / n_seen, n_correct / n_seen


def train(model, train_loader, val_loader, epochs, patience):
  """Trains `model` with early stopping and restores the best weights.

  Uses the module-level `optimizer`, `loss_fn` and `device` objects. After
  every epoch the validation loss is compared with the best one seen so far;
  training stops once it has failed to improve for `patience` consecutive
  epochs. When the function returns, `model` holds the weights from the epoch
  with the lowest validation loss.

  Args:
    model: Network producing probabilities of shape (batch, 1).
    train_loader: Iterable of (inputs, labels) batches used for optimisation.
    val_loader: Iterable of (inputs, labels) batches used for model selection.
    epochs: Maximum number of epochs to run.
    patience: Number of consecutive non-improving epochs that triggers early
      stopping.
  """
  best_loss = float('inf')
  best_model_state = None
  epochs_no_improve = 0

  for epoch in range(epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, training=True)
    with torch.no_grad():
      val_loss, val_acc = _run_epoch(model, val_loader, training=False)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if val_loss < best_loss:
      best_loss = val_loss
      # state_dict() hands back references to the live tensors, which later
      # optimizer steps overwrite in place; clone them to freeze a snapshot.
      best_model_state = {
          name: tensor.detach().clone()
          for name, tensor in model.state_dict().items()
      }
      epochs_no_improve = 0
    else:
      epochs_no_improve += 1
      if epochs_no_improve >= patience:
        print("Early stopping triggered")
        break

  # Restore once, after the loop, so it also runs when early stopping breaks
  # out and never rewinds the weights in the middle of training.
  if best_model_state is not None:
    model.load_state_dict(best_model_state)

In [7]:
dataset_path = "/content/drive/MyDrive/braai_cnn/ztf_dataset_split.npz"

# Leading 80% of the archive for training, remaining 20% for validation.
train_dataset, val_dataset = [
    ZTFDataset(dataset_path, train=is_train, split=0.8)
    for is_train in (True, False)
]

# Only the training batches are shuffled; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=64)

In [8]:
# Run at most 200 epochs; stop early after 10 epochs without a better validation loss.
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=200,
    patience=10,
)

Epoch 1/200, Train Loss: 0.6627, Train Acc: 0.6039, Val Loss: 0.6287, Val Acc: 0.6456
Epoch 2/200, Train Loss: 0.6060, Train Acc: 0.6636, Val Loss: 0.5422, Val Acc: 0.7026
Epoch 3/200, Train Loss: 0.5362, Train Acc: 0.7224, Val Loss: 0.5029, Val Acc: 0.7411
Epoch 4/200, Train Loss: 0.5038, Train Acc: 0.7464, Val Loss: 0.4690, Val Acc: 0.7582
Epoch 5/200, Train Loss: 0.4822, Train Acc: 0.7568, Val Loss: 0.4451, Val Acc: 0.7774
Epoch 6/200, Train Loss: 0.4640, Train Acc: 0.7747, Val Loss: 0.4354, Val Acc: 0.7851
Epoch 7/200, Train Loss: 0.4542, Train Acc: 0.7807, Val Loss: 0.4190, Val Acc: 0.7945
Epoch 8/200, Train Loss: 0.4403, Train Acc: 0.7897, Val Loss: 0.4203, Val Acc: 0.7922
Epoch 9/200, Train Loss: 0.4310, Train Acc: 0.7948, Val Loss: 0.4330, Val Acc: 0.7832
Epoch 10/200, Train Loss: 0.4223, Train Acc: 0.7981, Val Loss: 0.4314, Val Acc: 0.7834
Epoch 11/200, Train Loss: 0.4132, Train Acc: 0.8054, Val Loss: 0.3864, Val Acc: 0.8117
Epoch 12/200, Train Loss: 0.4075, Train Acc: 0.8106,

In [10]:
# Persist only the weights, wrapped in a dict under the 'model' key.
# NOTE(review): presumably the models/ folder must already exist on Drive -- confirm.
torch.save(
    obj={'model': model.state_dict()},
    f='/content/drive/MyDrive/braai_cnn/models/pytorch_cnn.pt',
)