In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [15]:
####Loading directories
npz_path = "/home/tharushi/Desktop/hackathon/10061/patches_128/real_plus_fft_dataset_128.npz"
data = np.load(npz_path)
X = data["X"]
y = data["y"]

print("X:", X.shape, X.dtype)  # (N,2,128,128)
print("y:", y.shape, y.dtype) 

X: (4000, 2, 128, 128) float32
y: (4000,) int64


In [16]:
class NPZDataset(Dataset):
    def __init__(self, npz_path):
        d = np.load(npz_path)
        self.X = d["X"].astype(np.float32)
        self.y = d["y"].astype(np.int64)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y[idx], dtype=torch.long)

dataset = NPZDataset(npz_path)
print("dataset size:", len(dataset), "pos rate:", dataset.y.mean())


dataset size: 4000 pos rate: 0.5


In [17]:
class IceCNN_FFT(nn.Module):
    def __init__(self, in_channels=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 128 -> 64

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 64 -> 32

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 32 -> 16

            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Linear(128, 2)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

model = IceCNN_FFT(in_channels=X.shape[1]).to(DEVICE)


In [18]:
indices = list(range(len(dataset)))

train_idx, val_idx = train_test_split(
    indices, test_size=0.2, stratify=dataset.y, random_state=42
)

train_ds = torch.utils.data.Subset(dataset, train_idx)
val_ds   = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)

# sanity check batch shapes
xb, yb = next(iter(train_loader))
print("batch X:", xb.shape, "batch y:", yb.shape)  # should be (64,2,128,128) and (64,)


batch X: torch.Size([64, 2, 128, 128]) batch y: torch.Size([64])


In [19]:
EPOCHS = 5
LR = 1e-3

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    total = 0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total += imgs.size(0)

    avg_loss = total_loss / total

    # validation
    model.eval()
    val_preds, val_true = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            logits = model(imgs)
            preds = logits.argmax(dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_true.extend(labels.cpu().numpy())

    print(f"\nEpoch {epoch}/{EPOCHS} - train loss: {avg_loss:.4f}")
    print(classification_report(val_true, val_preds, target_names=["BAD", "GOOD"]))  # adjust names if needed



Epoch 1/5 - train loss: 0.6935
              precision    recall  f1-score   support

         BAD       0.00      0.00      0.00       400
        GOOD       0.50      1.00      0.67       400

    accuracy                           0.50       800
   macro avg       0.25      0.50      0.33       800
weighted avg       0.25      0.50      0.33       800



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 2/5 - train loss: 0.6932
              precision    recall  f1-score   support

         BAD       0.00      0.00      0.00       400
        GOOD       0.50      1.00      0.67       400

    accuracy                           0.50       800
   macro avg       0.25      0.50      0.33       800
weighted avg       0.25      0.50      0.33       800



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 3/5 - train loss: 0.6932
              precision    recall  f1-score   support

         BAD       0.00      0.00      0.00       400
        GOOD       0.50      1.00      0.67       400

    accuracy                           0.50       800
   macro avg       0.25      0.50      0.33       800
weighted avg       0.25      0.50      0.33       800



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 4/5 - train loss: 0.6932
              precision    recall  f1-score   support

         BAD       0.50      1.00      0.67       400
        GOOD       0.00      0.00      0.00       400

    accuracy                           0.50       800
   macro avg       0.25      0.50      0.33       800
weighted avg       0.25      0.50      0.33       800



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 5/5 - train loss: 0.6932
              precision    recall  f1-score   support

         BAD       0.00      0.00      0.00       400
        GOOD       0.50      1.00      0.67       400

    accuracy                           0.50       800
   macro avg       0.25      0.50      0.33       800
weighted avg       0.25      0.50      0.33       800



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
