In [1]:
from ML4transients.training import PytorchDataset, CustomCNN, get_trainer, get_loss_function

In [4]:
# Build the dataset splits from the on-disk cutout directory. The result is
# indexed by the split names 'train', 'val' and 'test' in later cells.
# random_state=42 presumably seeds the split assignment so that it is
# reproducible -- TODO confirm against PytorchDataset.create_splits.
# NOTE(review): the data location is an absolute, site-specific path; consider
# moving it to a configuration constant near the top of the notebook.
datasets = PytorchDataset.create_splits(
    '/sps/lsst/groups/transients/HSC/fouchez/raphael/UDEEP_COSMOS2',
    random_state=42,
)

Building sample index...
Creating splits from 8490713 samples...
Loading 5943498 cutouts...
Loading 849072 cutouts...
Loading 1698143 cutouts...


In [5]:
# Bind the 'val' split to its own name for the loader built in the next cell.
# NOTE(review): a later cell repeats this assignment together with the 'train'
# and 'test' splits, so this cell looks redundant -- confirm before removing.
val_dataset = datasets['val'] 


In [8]:
from torch.utils.data import DataLoader

# Validation loader: batches of 64 served in dataset order (shuffle=False)
# with num_workers=0, i.e. no worker subprocesses are requested.
# NOTE(review): `val_loader` is rebound further down with the batch size taken
# from `config`; whichever cell ran last decides the object actually in use.
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False, num_workers=0)

In [3]:
# Give each split produced by create_splits its own name; the DataLoader
# cell below consumes all three.
train_dataset, test_dataset, val_dataset = (
    datasets['train'],
    datasets['test'],
    datasets['val'],
)



In [None]:
from torch.utils.data import DataLoader

# Run configuration handed to get_trainer(...) and read by the DataLoader
# setup in this cell ("batch_size" and "num_workers").
# The "model_params" entries use the same names as the keyword arguments of
# the CustomCNN constructor defined later in this notebook.
# NOTE(review): how the trainer interprets "epochs", "num_iter_per_epoch" and
# "epoch_decay_start" is not visible here -- confirm against the trainer code.
config = dict(
    epochs=200,
    learning_rate=0.001,
    batch_size=128,
    num_iter_per_epoch=400,
    epoch_decay_start=80,
    num_workers=0,
    model_params=dict(
        input_shape=[30, 30, 1],
        num_classes=2,
        filters_1=32,
        filters_2=64,
        dropout_1=0.25,
        dropout_2=0.25,
        dropout_3=0.5,
        units=128,
    ),
)

# One loader per split. All three take their batch size and worker count from
# `config` (128 and 0 respectively with the values set above).
# Only the training loader reshuffles its samples.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
)

# Evaluation loaders keep the dataset order.
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
)
val_loader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
)
# Ask the project factory for the "standard" trainer, configured from `config`.
# NOTE(review): no model object is passed in here; presumably the trainer
# builds its own network from config["model_params"] -- confirm.
trainer = get_trainer("standard", config)




In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CustomCNN(nn.Module):
    """Small two-stage CNN that emits a single value per input image.

    Layout: two ``3x3 conv (padding=1) -> ReLU -> 2x2 max-pool -> Dropout2d``
    stages, followed by ``flatten -> Linear -> ReLU -> Dropout -> Linear``.

    Args:
        input_shape: ``(height, width, channels)`` of one image. The tensors
            given to :meth:`forward` must be channels-first, i.e.
            ``(batch, channels, height, width)``.
        num_classes: Accepted so the signature matches the ``model_params``
            configuration; it does not affect the layers. The head always has
            one output unit (binary classification).
        filters_1: Output channels of the first convolution.
        filters_2: Output channels of the second convolution.
        dropout_1: Drop probability of the ``Dropout2d`` after the first stage.
        dropout_2: Drop probability of the ``Dropout2d`` after the second stage.
        dropout_3: Drop probability of the ``Dropout`` before the output layer.
        units: Width of the hidden fully connected layer.

    Raises:
        ValueError: If ``input_shape`` does not have exactly three entries, or
            if it describes an image with no channels or one too small to
            survive the two 2x2 poolings.
    """

    def __init__(self, input_shape=(30, 30, 1), num_classes=2, filters_1=32, filters_2=64, 
                 dropout_1=0.25, dropout_2=0.25, dropout_3=0.5, units=128):
        super().__init__()

        if len(input_shape) != 3:
            raise ValueError(
                f"input_shape must be (height, width, channels), got {input_shape!r}"
            )
        height, width, channels = (int(dim) for dim in input_shape)

        # The padded 3x3 convolutions keep the spatial size; each 2x2 pooling
        # halves it with floor division (default 30 -> 15 -> 7).
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if channels < 1 or pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"input_shape {input_shape!r} is too small for two 2x2 poolings"
            )

        # Stage 1: expects input of shape (batch_size, channels, height, width).
        self.conv1 = nn.Conv2d(channels, filters_1, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(dropout_1)

        # Stage 2.
        self.conv2 = nn.Conv2d(filters_1, filters_2, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout2d(dropout_2)

        # Number of features per sample once the stage-2 output is flattened;
        # derived from input_shape instead of assuming a 30x30 image.
        self.flattened_size = filters_2 * pooled_height * pooled_width

        self.fc1 = nn.Linear(self.flattened_size, units)
        self.dropout3 = nn.Dropout(dropout_3)
        self.fc2 = nn.Linear(units, 1)  # Single output for binary classification

    def forward(self, x):
        """Map a ``(batch, channels, height, width)`` tensor to ``(batch, 1)``.

        No activation is applied after the last linear layer.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.dropout1(x)

        x = self.pool2(F.relu(self.conv2(x)))
        x = self.dropout2(x)

        # Collapse every non-batch dimension into one feature vector.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.dropout3(x)
        x = self.fc2(x)

        return x    

# Instantiate the CustomCNN defined in this cell with its default arguments.
# Note that this class shadows the CustomCNN imported from
# ML4transients.training in the first cell.
# NOTE(review): `model` is not passed to get_trainer(...) or to `trainer` in
# the visible cells -- confirm whether the trainer uses it at all.
model = CustomCNN()

In [22]:
# Smoke test: push one all-zero, single-channel 30x30 image through the model;
# the recorded output is a tensor holding one value.
# NOTE(review): no model.eval() call precedes this, so the dropout layers are
# presumably still active; call model.eval() first for a deterministic check.
model(torch.zeros(1, 1, 30, 30))

tensor([[0.1114]], grad_fn=<AddmmBackward0>)

In [28]:
# Run a single training epoch over `train_loader`; the recorded output is a
# dict with 'loss' and 'accuracy' entries.
# The first argument (1) is presumably the epoch index -- TODO confirm against
# the trainer's train_one_epoch signature.
# NOTE(review): the recorded stderr contains multiprocessing finalizer
# tracebacks, and the execution counts in this notebook are out of order;
# re-run from a fresh kernel to confirm the result is reproducible.
trainer.train_one_epoch(1, train_loader)

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/pbs/home/r/rbonnetguerrini/.conda/envs/env_ML/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/pbs/home/r/rbonnetguerrini/.conda/envs/env_ML/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/pbs/home/r/rbonnetguerrini/.conda/envs/env_ML/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/pbs/home/r/rbonnetguerrini/.conda/envs/env_ML/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/pbs/home/r/rbonnetguerrini/.conda/envs/env_ML/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/pbs/home/r/rbonnetguerrini/.conda/envs/env_ML/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = se

{'loss': 0.5919589725343587, 'accuracy': 0.8809118496040264}