## Learning about classes in PyTorch
Runs with Python 3.8.2 on Linux and Python 3.8.8 on macOS.

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Build data loader

In [None]:
import pyarrow.feather as feather
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd



In [None]:
def build_dataset(batchSizeTrain, batchsizeValid,
                  flux_path='data/fluxData.feather',
                  zernike_path='data/zernikeData.feather',
                  test_size=0.2, random_state=42):
    """Load the flux/Zernike tables and wrap them in train/validation DataLoaders.

    The flux table is used as network input and the Zernike table as the
    regression target; rows of the two tables are assumed to correspond
    one-to-one (TODO confirm). Inputs are standardised per column using
    statistics from the training split only, so nothing about the
    validation rows leaks into the normalisation.

    Args:
        batchSizeTrain: batch size of the training loader.
        batchsizeValid: batch size of the validation loader.
        flux_path: feather file holding the input features.
        zernike_path: feather file holding the targets.
        test_size: fraction of rows held out for validation.
        random_state: seed for the train/validation split.

    Returns:
        (loaderTrain, loaderValid): the training loader (shuffled) and the
        validation loader (fixed order), both yielding float32
        (input, target) tensor pairs.
    """
    flux_df = feather.read_feather(flux_path)
    zernike_df = feather.read_feather(zernike_path)

    # Split BEFORE normalising: mean/std must come from the training rows only.
    X_train, X_val, y_train, y_val = train_test_split(
        flux_df, zernike_df, test_size=test_size, random_state=random_state)

    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    # A constant column has std 0; divide by 1 instead to avoid NaN/inf inputs.
    std = std.replace(0, 1)
    X_train = (X_train - mean) / std
    X_val = (X_val - mean) / std

    def _to_tensor(df):
        # Converts a DataFrame to a float32 tensor, matching the network's dtype.
        return torch.tensor(df.to_numpy(dtype=np.float32))

    train_dataset = torch.utils.data.TensorDataset(_to_tensor(X_train), _to_tensor(y_train))
    valid_dataset = torch.utils.data.TensorDataset(_to_tensor(X_val), _to_tensor(y_val))

    loaderTrain = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batchSizeTrain, shuffle=True)
    # Validation metrics do not depend on sample order, so keep it deterministic.
    loaderValid = torch.utils.data.DataLoader(
        dataset=valid_dataset, batch_size=batchsizeValid, shuffle=False)
    return loaderTrain, loaderValid



# Build model

## Build the AO network as a class

In [None]:

class Net(nn.Module):
    """Fully connected regression network: input -> 3 hidden layers -> output.

    Layer widths default to the original hard-coded architecture. When
    ``layer1`` is not given it is read from ``wandb.config.layer1``, so a
    bare ``Net()`` inside a W&B run behaves exactly as before.

    Args:
        in_features: size of each input row (19 by default).
        layer1: width of the first hidden layer; None reads wandb.config.layer1.
        layer2: width of the second hidden layer (1050 by default).
        layer3: width of the third hidden layer (100 by default).
        out_features: size of each output row (9 by default).
    """

    def __init__(self, in_features=19, layer1=None, layer2=1050, layer3=100, out_features=9):
        super().__init__()
        if layer1 is None:
            # Requires an active wandb run whose config defines layer1.
            layer1 = wandb.config.layer1
        # Registration order matches the original so state_dict keys are unchanged.
        self.linear1 = nn.Linear(in_features, layer1)
        self.linear2 = nn.Linear(layer1, layer2)
        self.linear3 = nn.Linear(layer2, layer3)
        self.out = nn.Linear(layer3, out_features)
        # Parameter-free activations, selected by name in forward().
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'lrelu': nn.LeakyReLU(),
        })

    def forward(self, x, act='relu'):
        """Return the network output for ``x``.

        ``act`` names the hidden-layer activation ('relu' or 'lrelu'); an
        unknown name raises KeyError from the ModuleDict lookup. The output
        layer has no activation.
        """
        activation = self.activations[act]
        for layer in (self.linear1, self.linear2, self.linear3):
            x = activation(layer(x))
        return self.out(x)





## Set up W&B (Weights & Biases)

In [None]:
import wandb
#wandb.init(project="test")

In [None]:
import os  # was missing: os.environ below raised NameError

# Tell wandb which notebook this run comes from.
os.environ["WANDB_NOTEBOOK_NAME"] = "PytorchClassExperiment.ipynb"

# Search space for the W&B sweep. A parameter only has an effect if the
# training code reads it from wandb.config.
sweep_config = {
    "method": "random",  # try grid or random
    "metric": {
        # Must match a key passed to wandb.log() during training, otherwise
        # the sweep has nothing to optimise.
        "name": "RmsLossValid",
        "goal": "minimize",
    },
    "parameters": {
        "learning_rate": {"values": [0.005, 0.001]},
        "lrFactor": {"values": [0.2, 0.5]},
        "batch_size": {"values": [128]},
        "epochs": {"values": [100]},
        "NoLayers": {"values": [3, 4]},
        "layer1": {"values": [2000, 3000, 4000]},
        "layer2": {"values": [1050, 2050]},
        "layer3": {"values": [200, 400, 1050, 2050]},
        "layer4": {"values": [200, 400, 1050, 2050]},
    },
}

In [None]:
# Register the sweep with W&B; the returned id is what wandb.agent() runs against.
sweep_id = wandb.sweep(sweep=sweep_config, project="pytorch-sweeps")

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs, "relu"), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        n_samples += inputs.size(0)
    # Guard against an empty loader.
    return total_loss / max(n_samples, 1)


def _validation_rms(model, loader):
    """Return the root-mean-squared error of ``model`` over every element in ``loader``."""
    model.eval()
    squared_error, n_elements = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs, "relu")
            squared_error += torch.sum((predictions - targets) ** 2).item()
            n_elements += targets.numel()
    return (squared_error / max(n_elements, 1)) ** 0.5


def train(config=None):
    """Train one Net inside a W&B run and log per-epoch training loss and validation RMS.

    Hyper-parameters are read from ``wandb.config``. Under ``wandb.agent``
    the sweep supplies them. For a standalone call, the values in ``config``
    (falling back to the defaults below) are used. Keys read here are
    learning_rate, batch_size and epochs; ``Net()`` additionally reads layer1.

    Args:
        config: optional dict of hyper-parameters for a standalone run.
    """
    # Fallbacks so a plain ``train()`` works outside a sweep; when run by a
    # sweep agent, the sweep's values take precedence over these.
    run_config = {
        "learning_rate": 0.001,
        "batch_size": 1024,
        "epochs": 5,
        "layer1": 2000,
    }
    run_config.update(config or {})

    # The context manager finishes the run, so an agent can start the next one cleanly.
    with wandb.init(config=run_config):
        config = wandb.config
        model = Net().to(device)
        print(model)
        optimizer = optim.SGD(model.parameters(), lr=config.learning_rate, momentum=0.9)
        criterion = nn.MSELoss()
        loaderTrain, loaderValid = build_dataset(config.batch_size, config.batch_size)
        wandb.watch(model, log_freq=1)

        for epoch in range(config.epochs):
            train_loss = _train_one_epoch(model, loaderTrain, optimizer, criterion)
            rms_valid = _validation_rms(model, loaderValid)
            print("epoch: ", epoch, "loss: ", "%.4f" % train_loss, "RmsLossValid: ", "%.4f" % rms_valid)
            wandb.log({"epoch": epoch, "loss": train_loss, "RmsLossValid": rms_valid})


In [None]:
# Let the sweep agent call train() for two sampled hyper-parameter configurations.
wandb.agent(sweep_id, function=train, count=2)

In [None]:
# One extra training run outside the sweep agent.
train(config=None)