## Learn about classes and PyTorch
Runs with Python 3.8.2 on Linux and Python 3.8.8 on macOS.

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

Set up TensorBoard

In [None]:
# Optional notebook install commands, left disabled.
# NOTE(review): torch.utils.tensorboard should only need the `tensorboard`
# package, not TensorFlow itself — confirm before re-enabling these.
#!pip install tensorflow
#!pip install tensorflow --upgrade
from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer; train() logs its scalars and the model
# graph through it. No log directory is passed, so the library default is used.
writer = SummaryWriter()

# Set up wandb (Weights & Biases)

In [None]:
import wandb
# Close any run left open by an earlier execution of this notebook. The
# parentheses matter: a bare `wandb.finish` only references the function.
wandb.finish()


In [None]:
import os

# Tell wandb which notebook file this session belongs to.
os.environ["WANDB_NOTEBOOK_NAME"] = "AoKerasStudio.ipynb"

# Search space handed to wandb.sweep(); every parameter lists its candidate values.
sweep_config = {
    "method": "random",  # try grid or random
    # Objective of the sweep: the validation metric that train() logs.
    "metric": {"name": "RmsLossValid", "goal": "minimize"},
    "parameters": {
        # Optimizer settings.
        "learning_rate": {"values": [0.005, 0.001]},
        "lrFactor": {"values": [0.5]},  # factor passed to ReduceLROnPlateau
        # Training schedule.
        "batch_size": {"values": [512]},
        "epochs": {"values": [100]},
        # "NoLayers": {"values": [3, 4]},  # disabled
        # Network shape: hidden-layer widths and the batch-norm switch (0/1).
        "layer1": {"values": [2000, 3000, 4000]},
        "batchNorm": {"values": [0, 1]},
        "layer2": {"values": [1050, 2050]},
        "layer3": {"values": [200, 400, 1050, 2050]},
    },
}

# Build data loader

In [None]:
import pyarrow.feather as feather
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd



In [None]:
def build_dataset(batchSizeTrain, batchsizeValid):
    """Load the flux/Zernike tables and wrap them in train/validation loaders.

    Reads ``data/fluxData.feather`` (network inputs) and
    ``data/zernikeData.feather`` (regression targets), z-scores every input
    column, and splits the rows 80/20 with a fixed random seed.

    Args:
        batchSizeTrain: Batch size of the training loader (shuffled each epoch).
        batchsizeValid: Batch size of the validation loader (fixed order).

    Returns:
        Tuple ``(loaderTrain, loaderValid)`` of ``DataLoader`` objects that
        yield ``(input, target)`` pairs of float32 tensors.
    """
    fluxData_df = feather.read_feather('data/fluxData.feather')
    zernikeData_df = feather.read_feather('data/zernikeData.feather')

    # normalize input data (per-column z-score)
    # NOTE(review): mean/std are computed on the full table before the split,
    # so validation rows influence the normalization — confirm this is intended.
    fluxData_df_norm = (fluxData_df - fluxData_df.mean(axis=0)) / fluxData_df.std(axis=0)

    X_train, X_val, y_train, y_val = train_test_split(
        fluxData_df_norm, zernikeData_df, test_size=0.2, random_state=42)

    def to_loader(inputs_df, targets_df, batch_size, shuffle):
        """Convert a pair of DataFrames into a DataLoader of float32 tensors."""
        inputs = torch.tensor(inputs_df.values.astype(np.float32))
        targets = torch.tensor(targets_df.values.astype(np.float32))
        dataset = torch.utils.data.TensorDataset(inputs, targets)
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle)

    loaderTrain = to_loader(X_train, y_train, batchSizeTrain, shuffle=True)
    # Validation order carries no information; keeping it fixed makes the
    # batches seen during evaluation reproducible from epoch to epoch.
    loaderValid = to_loader(X_val, y_val, batchsizeValid, shuffle=False)
    return loaderTrain, loaderValid



# Build data loader

## Build the AO network as a class

In [None]:
# Peek at the candidate widths of the first hidden layer. No `config` object
# is bound at module level (train() only binds it locally) and the sweep key
# is spelled "layer1", so read the values from sweep_config instead.
sweep_config["parameters"]["layer1"]

In [None]:
import torch.nn.functional as F
class Net(nn.Module):
    """Fully connected regressor: 19 inputs, three hidden layers, 9 outputs.

    Args:
        config: Object exposing ``layer1``, ``layer2`` and ``layer3`` (hidden
            layer widths) and ``batchNorm`` (``1`` enables batch
            normalization after each hidden linear layer) as attributes,
            e.g. ``wandb.config``.
    """

    def __init__(self, config):
        super().__init__()
        # Capture the switch at construction time so forward() uses the config
        # this instance was built with instead of the global wandb.config.
        self.use_batch_norm = config.batchNorm == 1

        self.linear1 = nn.Linear(19, config.layer1)
        self.linear1_bn = nn.BatchNorm1d(config.layer1)
        self.linear2 = nn.Linear(config.layer1, config.layer2)
        self.linear2_bn = nn.BatchNorm1d(config.layer2)
        self.linear3 = nn.Linear(config.layer2, config.layer3)
        self.linear3_bn = nn.BatchNorm1d(config.layer3)

        self.out = nn.Linear(config.layer3, 9)
        # Not applied in forward(); kept so the module's attributes are unchanged.
        self.dropout = nn.Dropout(p=0.0001)
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'lrelu': nn.LeakyReLU(),
        })

    def forward(self, x, act="relu"):
        """Map a batch of inputs to the network outputs.

        Args:
            x: Input tensor whose last dimension is 19.
            act: Key into ``self.activations`` (``"relu"`` or ``"lrelu"``)
                selecting the non-linearity used after every hidden layer.

        Returns:
            Tensor whose last dimension is 9.

        Raises:
            KeyError: If ``act`` is not a registered activation name.
        """
        activation = self.activations[act]
        hidden_layers = (
            (self.linear1, self.linear1_bn),
            (self.linear2, self.linear2_bn),
            (self.linear3, self.linear3_bn),
        )
        for linear, batch_norm in hidden_layers:
            x = linear(x)
            if self.use_batch_norm:
                x = batch_norm(x)
            x = activation(x)
        return self.out(x)


In [None]:
def _validation_rms(model, loader):
    """Return the root-mean-square error of ``model`` over all of ``loader``.

    Squared errors are summed across every batch and divided by the total
    number of target elements, so the result covers the whole loader and does
    not depend on batch size or batch order.
    """
    squared_error_sum = nn.MSELoss(reduction="sum")
    total_squared_error = 0.0
    total_elements = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            total_squared_error += squared_error_sum(model(inputs, "relu"), labels).item()
            total_elements += labels.numel()
    if total_elements == 0:
        raise ValueError("validation loader yielded no samples")
    return (total_squared_error / total_elements) ** 0.5


def train(config=None):
    """Train one ``Net`` and log the run; entry point for the wandb sweep agent.

    Each epoch trains on every batch of the training loader, evaluates the RMS
    error over the full validation loader, feeds that value to the
    ``ReduceLROnPlateau`` scheduler, and logs the epoch-mean training loss,
    the validation RMS and the learning rate to stdout, TensorBoard and wandb.
    After the last epoch the model graph is written to TensorBoard and a
    checkpoint is saved to ``PaperModel.pt``.

    Args:
        config: Optional hyper-parameter defaults forwarded to ``wandb.init``.
            The values actually used are then read back from ``wandb.config``.
    """
    wandb.init(project="PytorchWandbSweep", config=config)
    config = wandb.config
    print(config)
    ClassNetwork = Net(config).to(device)
    wandb.watch(ClassNetwork, log_freq=10)

    optimizer = optim.AdamW(ClassNetwork.parameters(), lr=config.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=config.lrFactor, cooldown=5, patience=5,
        verbose=True)
    loaderTrain, loaderValid = build_dataset(config.batch_size, 2048)
    criterion = nn.MSELoss()

    for epoch in range(config.epochs):
        ClassNetwork.train()
        # Accumulate on the device so there is one host sync per epoch.
        weighted_loss_sum = torch.zeros((), device=device)
        samples_seen = 0
        for inputs, labels in loaderTrain:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(ClassNetwork(inputs, "relu"), labels)
            # Backward pass + weight update
            loss.backward()
            optimizer.step()
            batch_samples = labels.size(0)
            weighted_loss_sum += loss.detach() * batch_samples
            samples_seen += batch_samples
        train_loss = (weighted_loss_sum / samples_seen).item()

        # Evaluate on the whole validation set; a single batch would make the
        # scheduler signal and the sweep objective needlessly noisy.
        rms_valid = _validation_rms(ClassNetwork, loaderValid)

        scheduler.step(rms_valid)
        actualLR = optimizer.param_groups[0]["lr"]
        print("epoch: ", epoch, "loss: ", "%.6f" % train_loss, "RmsLossValid: ", "%.6f" % rms_valid, "Learning rate:", "%.8f" % actualLR)
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('RmsLossValid/valid', rms_valid, epoch)
        wandb.log({"loss": train_loss, "RmsLossValid": rms_valid, "LearningRate": actualLR})

    # Trace the graph with one validation batch, in inference mode.
    ClassNetwork.eval()
    data, _ = next(iter(loaderValid))
    writer.add_graph(ClassNetwork, data.to(device))
    # The writer is a module-level object shared by every run the agent
    # starts, so flush pending events here rather than closing it.
    writer.flush()
    torch.save({
            'epoch': epoch,
            'model_state_dict': ClassNetwork.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as 0-dim tensors, as before, so code that reads the
            # checkpoint with .item() keeps working.
            'loss': torch.tensor(train_loss),
            "RmsLossValid": torch.tensor(rms_valid)
            }, "PaperModel.pt")

In [None]:
""" ClassNetwork = Net()
print(ClassNetwork)
optimizer = optim.AdamW(ClassNetwork.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
loaderTrain, loaderValid =build_dataset(128,128)
epochs = 8
for epoch in range(epochs):
    ClassNetwork.train()
    for i, data in enumerate(loaderTrain,0):
        input, labels = data
        optimizer.zero_grad()
        outputs = ClassNetwork(input, "relu")
        loss = nn.MSELoss()
        loss =loss(ClassNetwork(input, "relu"), labels)
# ⬅ Backward pass + weight update
        loss.backward()
        optimizer.step()
    for i, data in enumerate(loaderValid,0):
        ClassNetwork.eval()
        lossVal = nn.MSELoss()            
        lossVal =lossVal(ClassNetwork(input, "relu"), labels)
        RmsLossValid=torch.sqrt(loss) 
        scheduler.step(RmsLossValid)   
    print(optimizer.param_groups[0]["lr"])
    print("epoch: ", epoch, "loss: ", "%.4f" % loss.item() ,"RmsLossValid: " , "%.4f" % RmsLossValid.item())  
 """

In [None]:
# ClassNetwork = Net()

# loaderTrain, loaderValid =build_dataset(1024,1024)
# data, labels = next(iter(loaderValid))
# ClassNetwork( data, "relu")

In [None]:
# Register the sweep described by sweep_config and keep its ID for the agent.
# NOTE(review): the sweep is created under project "PytorchSweepWorking" while
# train() calls wandb.init with project "PytorchWandbSweep" — confirm which
# project the runs are meant to land in.
sweep_id = wandb.sweep(sweep_config, project="PytorchSweepWorking")
# Bare expression: as the last line of a notebook cell its value is displayed.
wandb.run

In [None]:
# Start a sweep agent that uses train() as its job function; count=100 limits
# how many runs it executes.
wandb.agent(sweep_id, train, count=100)

In [None]:
#train(config)