## Learn about classes and PyTorch
Runs with Python 3.8.2 on Linux and Python 3.8.8 on macOS.

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
# Compute device: fall back to the CPU unless a CUDA GPU is available.
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")


Set up TensorBoard

In [None]:
# TensorBoard logging setup.
# NOTE(review): the tensorflow installs below are disabled; torch.utils.tensorboard
# presumably only needs the `tensorboard` package -- confirm before re-enabling.
#!pip install tensorflow
#!pip install tensorflow --upgrade
from torch.utils.tensorboard import SummaryWriter
# Module-level writer, created with default arguments, used by train() for its scalars and graph.
writer = SummaryWriter()

# Build data loader

In [None]:
import pyarrow.feather as feather
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd



In [None]:
def build_dataset(batchSizeTrain, batchsizeValid):
    """Load the flux/Zernike tables and wrap them in train/validation loaders.

    Reads ``data/fluxData.feather`` (model inputs) and
    ``data/zernikeData.feather`` (regression targets), holds out 20% of the
    rows for validation (fixed ``random_state=42``) and standardises the
    inputs column-wise.

    The standardisation statistics (mean / std) are computed from the
    training rows only and then applied to both splits, so no information
    about the validation rows leaks into the features used for training.
    Columns with zero variance are centred but not rescaled, which avoids
    a division by zero.

    Args:
        batchSizeTrain: Batch size of the training loader.
        batchsizeValid: Batch size of the validation loader.

    Returns:
        Tuple ``(loaderTrain, loaderValid)`` of ``DataLoader`` objects, each
        yielding ``(inputs, targets)`` pairs of float32 tensors. The training
        loader is reshuffled every epoch; the validation loader is not.
    """
    def _make_loader(features_df, targets_df, batch_size, shuffle):
        # Convert a (features, targets) DataFrame pair into a DataLoader.
        features = torch.tensor(features_df.values.astype(np.float32))
        targets = torch.tensor(targets_df.values.astype(np.float32))
        dataset = torch.utils.data.TensorDataset(features, targets)
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

    fluxData_df = feather.read_feather('data/fluxData.feather')
    zernikeData_df = feather.read_feather('data/zernikeData.feather')

    # Split first, normalise afterwards: the split itself only depends on the
    # row count and random_state, so it is unaffected by the normalisation.
    X_train, X_val, y_train, y_val = train_test_split(
        fluxData_df, zernikeData_df, test_size=0.2, random_state=42)

    # Normalise inputs with training-set statistics only.
    train_mean = X_train.mean(axis=0)
    train_std = X_train.std(axis=0).replace(0, 1.0)  # constant columns: avoid 0/0
    X_train_norm = (X_train - train_mean) / train_std
    X_val_norm = (X_val - train_mean) / train_std

    loaderTrain = _make_loader(X_train_norm, y_train, batchSizeTrain, shuffle=True)
    # Shuffling adds nothing for evaluation; keep validation batches deterministic.
    loaderValid = _make_loader(X_val_norm, y_val, batchsizeValid, shuffle=False)
    return loaderTrain, loaderValid



# Build the model

## Build the AO network as a class

In [None]:
import torch.nn.functional as F
class Net(nn.Module):
    """Fully connected regression network: 19 -> 2000 -> 1050 -> 100 -> 9.

    Three hidden linear layers, each followed by a selectable activation,
    and a linear output layer without activation.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(19, 2000)
        self.linear2 = nn.Linear(2000, 1050)
        self.linear3 = nn.Linear(1050, 100)
        self.out = nn.Linear(100, 9)
        # Activations selectable by name in forward(); they hold no parameters.
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'lrelu': nn.LeakyReLU()
        })

    def forward(self, x, act="relu"):
        """Run the network on a batch.

        Args:
            x: Input tensor whose last dimension is 19.
            act: Name of the hidden-layer activation, ``"relu"`` (default)
                or ``"lrelu"``.

        Returns:
            Tensor with last dimension 9.

        Raises:
            ValueError: If ``act`` is not a known activation name.
        """
        # Previously ``act`` was accepted but ignored (ReLU was always used);
        # honour it now and reject unknown names instead of silently guessing.
        if act not in self.activations:
            raise ValueError(
                "unknown activation %r, expected one of %s"
                % (act, sorted(self.activations.keys()))
            )
        activation = self.activations[act]

        x = activation(self.linear1(x))
        x = activation(self.linear2(x))
        x = activation(self.linear3(x))
        return self.out(x)





In [None]:
def train(config=None):
    """Train ``Net`` on the flux -> Zernike regression task and log to TensorBoard.

    Builds the model, an AdamW optimizer (lr=0.001) and a
    ``ReduceLROnPlateau`` scheduler driven by the validation RMS loss, then
    trains for 800 epochs with batch size 128. After every epoch the mean
    training MSE and the validation RMS error (square root of the mean MSE
    over the whole validation loader) are printed and written to the
    module-level ``writer``. Finally the model graph is logged and the
    writer is closed.

    Args:
        config: Currently unused; kept so existing callers that pass a
            configuration object keep working.

    Returns:
        None.
    """
    ClassNetwork = Net().to(device)
    print(ClassNetwork)
    optimizer = optim.AdamW(ClassNetwork.parameters(), lr=0.001)
    # ``verbose=True`` is deprecated in recent PyTorch releases; learning-rate
    # changes are reported explicitly after scheduler.step() below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8, cooldown=20)
    loaderTrain, loaderValid = build_dataset(128, 128)
    criterion = nn.MSELoss()
    epochs = 800
    for epoch in range(epochs):
        # ---- training pass ----
        ClassNetwork.train()
        train_loss_sum = 0.0
        train_count = 0
        for inputs, labels in loaderTrain:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(ClassNetwork(inputs, "relu"), labels)
            # ⬅ Backward pass + weight update
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch is averaged correctly.
            train_loss_sum += loss.item() * inputs.size(0)
            train_count += inputs.size(0)
        train_loss = train_loss_sum / max(train_count, 1)

        # ---- validation pass: real validation batches, no gradient tracking ----
        ClassNetwork.eval()
        valid_loss_sum = 0.0
        valid_count = 0
        with torch.no_grad():
            for inputs, labels in loaderValid:
                inputs = inputs.to(device)
                labels = labels.to(device)
                lossVal = criterion(ClassNetwork(inputs, "relu"), labels)
                valid_loss_sum += lossVal.item() * inputs.size(0)
                valid_count += inputs.size(0)
        RmsLossValid = (valid_loss_sum / max(valid_count, 1)) ** 0.5

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(RmsLossValid)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print("epoch: ", epoch, "reducing learning rate to ", "%.4e" % lr_after)

        print("epoch: ", epoch, "loss: ", "%.4f" % train_loss, "RmsLossValid: ", "%.4f" % RmsLossValid)
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('RmsLossValid/valid', RmsLossValid, epoch)

    # Log the model graph using one validation batch as example input.
    data, labels = next(iter(loaderValid))
    writer.add_graph(ClassNetwork, data.to(device))
    # The original ``writer.close`` lacked parentheses and never ran.
    writer.close()

In [None]:
# Disabled scratch copy of the training loop (8 epochs, scheduler stepped per
# validation batch). It is wrapped in a bare string literal, so none of it executes.
""" ClassNetwork = Net()
print(ClassNetwork)
optimizer = optim.AdamW(ClassNetwork.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
loaderTrain, loaderValid =build_dataset(128,128)
epochs = 8
for epoch in range(epochs):
    ClassNetwork.train()
    for i, data in enumerate(loaderTrain,0):
        input, labels = data
        optimizer.zero_grad()
        outputs = ClassNetwork(input, "relu")
        loss = nn.MSELoss()
        loss =loss(ClassNetwork(input, "relu"), labels)
# ⬅ Backward pass + weight update
        loss.backward()
        optimizer.step()
    for i, data in enumerate(loaderValid,0):
        ClassNetwork.eval()
        lossVal = nn.MSELoss()            
        lossVal =lossVal(ClassNetwork(input, "relu"), labels)
        RmsLossValid=torch.sqrt(loss) 
        scheduler.step(RmsLossValid)   
    print(optimizer.param_groups[0]["lr"])
    print("epoch: ", epoch, "loss: ", "%.4f" % loss.item() ,"RmsLossValid: " , "%.4f" % RmsLossValid.item())  
 """

In [None]:
# ClassNetwork = Net()

# loaderTrain, loaderValid =build_dataset(1024,1024)
# data, labels = next(iter(loaderValid))
# ClassNetwork( data, "relu")

In [9]:
# Start training with the settings hard-coded inside train(); no config is passed.
train()