# In [None]:
"""Modified from https://skorch.readthedocs.io/en/stable/user/quickstart.html"""
from skorch.regressor import NeuralNetRegressor
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skorch.callbacks import EarlyStopping


# Regular PyTorch Module
class ConvNet(torch.nn.Module):
    """Small two-stage CNN with a CORN-style output layer.

    The network applies two ``Conv2d`` + ``MaxPool2d`` stages, flattens the
    result and feeds it to a linear layer with ``num_classes - 1`` outputs.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of classes. It is stored on the instance and the
            output layer has ``num_classes - 1`` units.
        image_size: ``(height, width)`` of the input images, used only to
            size the final linear layer. The default ``(28, 28)`` gives the
            294 flattened features this model was originally written for.

    Raises:
        ValueError: If ``image_size`` is smaller than 4 pixels in either
            dimension, which leaves nothing after the two pooling stages.
    """

    def __init__(self, in_channels, num_classes, *, image_size=(28, 28)):
        super().__init__()

        # num_classes is used by the corn loss function
        self.num_classes = num_classes

        height, width = image_size
        if height < 4 or width < 4:
            raise ValueError(
                f"image_size must be at least (4, 4), got {image_size!r}"
            )

        hidden_channels = 3
        out_channels = 6

        # Initialize CNN layers
        all_layers = [
            torch.nn.Conv2d(in_channels=in_channels,
                            out_channels=hidden_channels,
                            kernel_size=(3, 3), stride=(1, 1),
                            padding=1),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Conv2d(in_channels=hidden_channels,
                            out_channels=out_channels,
                            kernel_size=(3, 3), stride=(1, 1),
                            padding=1),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Flatten(),
        ]

        # The 3x3 convolutions (stride 1, padding 1) keep the spatial size,
        # while each 2x2 / stride-2 pooling halves it, rounding down.
        flat_features = out_channels * (height // 2 // 2) * (width // 2 // 2)

        # CORN output layer --------------------------------------
        # Regular classifier would use num_classes instead of
        # num_classes-1 below
        output_layer = torch.nn.Linear(flat_features, num_classes - 1)
        # ---------------------------------------------------------

        all_layers.append(output_layer)
        self.model = torch.nn.Sequential(*all_layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)

# Train a skorch regressor around ConvNet and score it on the validation data.
# NOTE(review): X_train, y_train, X_val and y_val are not defined in this
# section; presumably they are created elsewhere -- confirm.
net = NeuralNetRegressor(
    # ConvNet requires in_channels and num_classes and has no ``dropout``
    # argument, so the class is passed and skorch builds it from the
    # ``module__``-prefixed arguments below.
    ConvNet,
    # Conv2d consumes (N, C, H, W) batches, so axis 1 is the channel count.
    # Assumes X_train is such a 4-D array/tensor -- TODO confirm.
    module__in_channels=X_train.shape[1],
    # ConvNet emits num_classes - 1 values per sample; MSELoss needs that to
    # equal the target width. Assumes y_train is 2-D (n_samples, n_targets)
    # -- TODO confirm.
    module__num_classes=y_train.shape[1] + 1,
    criterion=nn.MSELoss,
    max_epochs=100,
    optimizer=optim.Adam,
    optimizer__lr=.005,
    callbacks=[EarlyStopping()],
)

net.fit(X_train, y_train)
y_pred = net.predict(X_val)

mae = mean_absolute_error(y_val, y_pred)
# ``squared=False`` is no longer accepted by recent scikit-learn releases.
# Taking the root of each target's MSE and averaging gives the same value
# that keyword used to produce.
rmse = np.mean(
    np.sqrt(mean_squared_error(y_val, y_pred, multioutput="raw_values"))
)