## Learn about class and pytorch
runs with 3.8.2 on linux and 3.8.8 on mac

In [139]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Setup tensorboard

In [140]:
#!pip install tensorflow
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

# Build data loader

In [141]:
import pyarrow.feather as feather
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd



In [142]:
def build_dataset(batchSizeTrain, batchsizeValid):
    fluxData_df = feather.read_feather('data/fluxData.feather')
    # normalize input data
    fluxData_df_norm = (fluxData_df - fluxData_df.mean(axis=0)) / fluxData_df.std(axis=0)
    zernikeData_df = feather.read_feather('data/zernikeData.feather')
    X_train, X_val, y_train, y_val = train_test_split(fluxData_df_norm, zernikeData_df, test_size=0.2, random_state=42)


    train_target = torch.tensor(y_train.values.astype(np.float32))
    trainInput = torch.tensor(X_train.values.astype(np.float32))

    train_tensor = torch.utils.data.TensorDataset(trainInput, train_target) 
    loaderTrain = torch.utils.data.DataLoader(dataset = train_tensor, batch_size = batchSizeTrain, shuffle = True)


    valid_target = torch.tensor(y_val.values.astype(np.float32))
    validInput = torch.tensor(X_val.values.astype(np.float32))

    train_tensor = torch.utils.data.TensorDataset(validInput, valid_target) 
    loaderValid = torch.utils.data.DataLoader(dataset = train_tensor, batch_size = batchsizeValid, shuffle = True)
    return loaderTrain, loaderValid



# Build data loader

## Build AO network as class

In [143]:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(19,10)
        self.linear2 = nn.Linear(10,10)
        self.linear3 = nn.Linear(10,100)
        self.out = nn.Linear(100,9)
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'lrelu': nn.LeakyReLU()
    })

    def forward(self, x, act = "relu"):
        x = self.linear1(x)
        x = self.activations[act](x)
        x = self.linear2(x)
        x = self.activations[act](x)
        x = self.linear3(x)
        x = self.activations[act](x)
        x = self.out(x)
        return(x)





In [144]:
def train(config=None):
    ClassNetwork = Net()
    print(ClassNetwork)
    optimizer = optim.SGD(ClassNetwork.parameters(), lr=0.001, momentum=0.9)
    loaderTrain, loaderValid =build_dataset(1024,1024)
    optimizer = optim.SGD(ClassNetwork.parameters(), lr=0.001, momentum=0.9)
    loaderTrain, loaderValid =build_dataset(1024,1024)
    epochs = 5
    for epoch in range(epochs):
        for i, data in enumerate(loaderValid,0):
            input, labels = data
            optimizer.zero_grad()
            outputs = ClassNetwork(input, "relu")
            loss = nn.MSELoss()
            loss =loss(ClassNetwork(input, "relu"), labels)
            RmsLossValid=torch.sqrt(loss)
    # ⬅ Backward pass + weight update
            loss.backward()
            optimizer.step()
        print("epoch: ", epoch, "loss: ", "%.4f" % loss.item() ,"RmsLossValid: " , "%.4f" % RmsLossValid.item())  
        writer.add_scalar('Loss/train', loss.item(), epoch)
        writer.add_scalar('RmsLossValid/train', RmsLossValid.item(), epoch)
    data, labels = next(iter(loaderValid))
    writer.add_graph(ClassNetwork, data  )
    writer.close

In [145]:
# ClassNetwork = Net()

# loaderTrain, loaderValid =build_dataset(1024,1024)
# data, labels = next(iter(loaderValid))
# ClassNetwork( data, "relu")

In [146]:
train()

Net(
  (linear1): Linear(in_features=19, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=10, bias=True)
  (linear3): Linear(in_features=10, out_features=100, bias=True)
  (out): Linear(in_features=100, out_features=9, bias=True)
  (activations): ModuleDict(
    (relu): ReLU()
    (lrelu): LeakyReLU(negative_slope=0.01)
  )
)
epoch:  0 loss:  0.0414 RmsLossValid:  0.2035
epoch:  1 loss:  0.0388 RmsLossValid:  0.1970
epoch:  2 loss:  0.0392 RmsLossValid:  0.1979
epoch:  3 loss:  0.0383 RmsLossValid:  0.1957
epoch:  4 loss:  0.0380 RmsLossValid:  0.1950
