In [1]:
import torch
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch_geometric.data import Batch

In [2]:
from GNN import GNN
from Dataset import XASDataset


In [3]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``train_loader``. ``epoch`` is accepted for the caller's convenience but
    is not read inside the function.

    Returns:
        The per-batch mean squared errors, each weighted by the number of
        graphs in its batch, summed and divided by ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss = 0

    for graphs in train_loader:
        graphs = graphs.to(device)

        optimizer.zero_grad()
        output = model(graphs)

        # Flatten prediction and target to one column each and compare them
        # in float64.
        target_col = graphs.y.view(-1, 1).double()
        output_col = output.view(-1, 1).double()
        batch_loss = F.mse_loss(output_col, target_col)

        batch_loss.backward()
        optimizer.step()

        # mse_loss returns a batch mean; scale it back by the batch size so
        # the final division yields a per-graph average.
        running_loss += batch_loss.item() * graphs.num_graphs

    return running_loss / len(train_loader.dataset)

In [4]:
def test(loader):
    """Evaluate the notebook-level ``model`` on every batch of ``loader``.

    The model is switched to eval mode and the forward pass runs without
    gradient tracking. Relies on the notebook-level ``model`` and ``device``.

    Returns:
        The per-batch mean squared errors, each weighted by the number of
        graphs in its batch, summed and divided by ``len(loader.dataset)``.
    """
    model.eval()

    def _weighted_mse(graphs):
        # Batch-mean MSE in float64, scaled by the number of graphs in the batch.
        graphs = graphs.to(device)
        with torch.no_grad():
            output = model(graphs)
        mse = F.mse_loss(output.view(-1, 1).double(),
                         graphs.y.view(-1, 1).double())
        return mse.item() * graphs.num_graphs

    return sum(_weighted_mse(graphs) for graphs in loader) / len(loader.dataset)


In [5]:
# Root folder handed to XASDataset. This is an absolute drive-letter path, so
# it has to be adjusted before the notebook is run on another machine.
path='E:/hlrn_orca/'
# Build the dataset object from that folder.
dataset = XASDataset(path)

Processing...


[54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 55, 55, 55, 55, 55, 56, 57, 57, 54, 54, 54, 54, 54, 54, 54, 54, 55, 55, 55, 55, 56, 56, 56, 56, 54, 55, 54, 55, 55, 55, 55, 56, 56, 57, 57, 57, 55, 55, 55, 55, 56, 57, 55, 55, 55, 56, 56, 56, 56, 56, 57, 57, 57, 57, 57, 57, 57, 58, 59, 55, 55, 55, 55, 56, 55, 55, 55, 56, 56, 56, 58, 55, 55, 56, 55, 56, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 26, 26, 24, 25, 25, 24, 24, 26, 25, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 24, 25, 24, 24, 24, 26, 24, 24, 24, 24, 24, 24, 24, 24, 25, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 26, 26, 25, 25, 25, 26, 26, 25, 25, 27, 25, 25, 27, 28, 27, 25, 25, 25, 25, 25, 25, 25, 25, 26, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 27, 26, 26, 26, 26, 27, 27, 27, 27, 27, 24, 24, 24, 24, 24, 24, 24, 24, 25, 24, 24, 24, 25, 24, 24, 24, 24, 25, 24, 25, 25, 25, 24, 24, 24, 26, 26, 24, 28, 24, 24, 24, 24, 24, 24, 24, 25, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 26,

Done!


In [None]:

# Contiguous, non-overlapping splits by position in the dataset:
#   [0, 9000)     training
#   [9000, 9800)  validation (drives the LR scheduler and model selection)
#   [9800, end)   held-out test
train_dataset = dataset[0:9000]
val_dataset = dataset[9000:9800]
# The test split starts where validation ends, so no graph is used both for
# model selection and for the final test error. Positions inside
# test_dataset therefore count from dataset item 9800.
test_dataset = dataset[9800:]

In [None]:
# Total number of items in the dataset; useful for checking it against the
# hard-coded split boundaries (9000 and 9800).
print(len(dataset))

In [None]:

# One loader per split, 30 graphs per batch. Only the training loader
# shuffles; validation and test keep the dataset order.
train_loader = DataLoader(train_dataset, batch_size=30, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=30, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=30, shuffle=False)

In [None]:
# Hyperparameters forwarded positionally to the GNN constructor below.
num_tasks = 200
num_layers = 4

# The first layer's input width is the dataset's feature count.
emb_dim = dataset.num_features
print(emb_dim)

# Per-layer channel sizes: the first layer reads emb_dim features, every
# other entry is 100.
in_channels = [int(emb_dim)] + [100] * 3
print(in_channels)
out_channels = [100] * 4

gnn_type = 'gat'
heads = 1
drop_ratio = 0.25
graph_pooling = 'sum'

In [None]:
# Use CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network from the hyperparameters above and move it to that device.
model = GNN(num_tasks, num_layers, emb_dim, in_channels, out_channels,
            gnn_type, heads, drop_ratio, graph_pooling).to(device)

In [None]:
# Bare expression: the notebook displays the model's repr as the cell output.
model

In [None]:
# Adam with the AMSGrad variant enabled and no weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=True,
)

# Scale the learning rate by 0.9 when the value passed to scheduler.step()
# stops decreasing (patience 20), never going below 1e-5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=20,
    min_lr=0.00001,
)


In [None]:

# Train for 200 epochs. After every epoch the validation loss drives the LR
# scheduler; whenever it matches or beats the best value seen so far, the
# test loss is re-evaluated so that test_error always belongs to the best
# validation epoch.
best_val_error = None
test_error = None
train_losses = []
val_losses = []

for epoch in range(200):
    loss = train(epoch)
    val_error = test(val_loader)

    # ReduceLROnPlateau needs the monitored metric on every step.
    scheduler.step(val_error)

    train_losses.append(loss)
    val_losses.append(val_error)

    # Read the LR after the scheduler step so the log shows the value the
    # next epoch will use.
    current_lr = optimizer.param_groups[0]['lr']

    if best_val_error is None or val_error <= best_val_error:
        test_error = test(test_loader)
        best_val_error = val_error

    if epoch % 10 == 0:
        # test_error is always set by now: the branch above runs in epoch 0.
        print(f'Epoch: {epoch:03d}, LR: {current_lr:.7f}, TrainLoss: {loss:.7f}, '
              f'ValLoss: {val_error:.7f}, TestLoss@BestVal: {test_error:.7f}')

In [None]:
from utils import plot_spectra,plot_learning_curve

In [None]:
# Write the model's current weights to model_gnn.pt in the working directory.
# NOTE(review): the training loop takes no checkpoint at the best validation
# epoch, so these are the weights the model holds when this cell runs.
torch.save(model.state_dict(), 'model_gnn.pt')

In [None]:
# First argument of plot_learning_curve. 200 matches the epoch count of the
# training loop (presumably the number of epochs to plot; TODO confirm
# against utils.plot_learning_curve).
num_e=200

# Plot the recorded per-epoch training and validation losses.
plot_learning_curve(num_e,train_losses,val_losses)

In [None]:

def pred_spec(model,index,test_dataset):
    """Predict the output for one item of ``test_dataset``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        index: Position of the item inside ``test_dataset``.
        test_dataset: Indexable collection whose items provide ``.to(device)``
            and a ``.y`` tensor.

    Returns:
        ``(predicted_spectrum, true_spectrum)``: the model output flattened to
        a 1-D NumPy array, and the item's ``y`` as a NumPy array.

    Uses the notebook-level ``device`` to place the item.
    """
    model.eval()

    graph_data = test_dataset[index].to(device)
    # Wrap the single item in a Batch, the input form the model is called
    # with in the training and evaluation loops.
    batch = Batch.from_data_list([graph_data])

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        pred = model(batch)

    predicted_spectrum = pred.cpu().numpy().reshape(-1)
    true_spectrum = graph_data.y.cpu().numpy()

    return predicted_spectrum,true_spectrum

In [None]:
# Load the saved model
# Rebuild the architecture with the same hyperparameters used for training,
# then restore the saved weights into it.
model = GNN(num_tasks,num_layers,emb_dim,in_channels,out_channels,gnn_type,heads,drop_ratio,graph_pooling)
model = model.to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written from a CUDA model also loads on a CPU-only machine.
model.load_state_dict(torch.load('model_gnn.pt', map_location=device))


In [None]:
# Forwarded to plot_spectra as its third argument (presumably a
# save-to-disk flag; TODO confirm against utils.plot_spectra).
save_var=1
# Prediction and target for the test_dataset item at position 200.
p1,t1=pred_spec(model,200,test_dataset)

plot_spectra(p1, t1,save_var)

In [None]:

# Second example: prediction and target for the test_dataset item at position 22.
p2,t2=pred_spec(model,22,test_dataset)

plot_spectra(p2, t2,save_var)

In [None]:
# Bundle both example predictions (p1, p2) and their targets (t1, t2) under
# string keys so they can be exported together.
spec_examples={'p1':p1,'t1':t1,'p2':p2,'t2':t2}


In [None]:
import pickle 

# Serialise the example dict with pickle. The target is an absolute
# drive-letter path and must be adjusted on other machines.
with open("E:/hlrn_orca/ml_preds.pkl", "wb") as file:
    pickle.dump(spec_examples, file)

In [None]:
# Pick the training item at position 30 for manual inspection.
t0=train_dataset[30]

In [None]:
# Display the item's x attribute (presumably the node-feature matrix, as is
# the convention for torch_geometric data objects; TODO confirm).
t0.x

In [None]:
# Run the model over the training loader for inspection. After the loop,
# `batch` and `embeddings` hold the last batch and its model output.
model.eval()
# Nothing is backpropagated here, so disable gradient tracking.
with torch.no_grad():
    for batch in train_loader:
        batch = batch.to(device)
        # Call the module itself rather than .forward() so registered hooks run.
        embeddings = model(batch)

In [None]:
# Display element [1][0] of the model output for the last processed batch.
embeddings[1][0]

In [None]:
# Display element [1][1] of the same output, for comparison with [1][0].
embeddings[1][1]