In [None]:
#Purpose
# - Check distribution of the data and different atom systems. 

In [2]:
import plotly.graph_objects as go
import numpy as np

In [14]:
#Load the data of one system, plot it and check its properties.
from FeGdSpinGNN.FeGd_dataset import FeGdMagneticDataset

dataset_train = FeGdMagneticDataset(
        root=r'FeGd',
        systems=[3],
        cutoff_dist=0.3,  # example cutoff distance
        use_static_features=False, # probaly not needed for now  
        include_z=True,
        load_energy=True,
    )

Loading systems: 100%|██████████| 1/1 [00:03<00:00,  3.41s/it]


In [15]:
from torch.utils.data import Subset
import torch

def indices_with_energy(ds):
    idxs = []
    for i in range(len(ds)):
        data = ds[i]
        if hasattr(data, "energy") and torch.is_tensor(data.energy):
            idxs.append(i)
        else:
            print(f"Dropping graph {i} from dataset (missing or invalid energy)")
    return idxs

train_idx = indices_with_energy(dataset_train)


dataset_train = Subset(dataset_train, train_idx)

Dropping graph 500 from dataset (missing or invalid energy)


In [16]:
data = dataset_train[50]
positions = data['pos'].numpy()
atom_types = data['x'][:,0].numpy()
spins = data['x'][:,2:5].numpy()
edge_index = data['edge_index'].numpy()
bfields = data['y'].numpy()

print(bfields)


[[  119.58839   1147.9271   -1363.0504  ]
 [-1122.3354     747.0981     577.43317 ]
 [  118.542336  -169.59465   -540.31555 ]
 ...
 [ 1257.7017   -1407.1543    4860.706   ]
 [  -58.937363 -3727.2053    1288.8804  ]
 [ 1388.2834    -101.5975    3310.0393  ]]


In [17]:
import plotly.graph_objects as go
import numpy as np

# Get all edges
src, dst = edge_index[0], edge_index[1]

# Create edge traces for ALL edges
edge_x, edge_y, edge_z = [], [], []
for i in range(edge_index.shape[1]):
    s, d = edge_index[0, i], edge_index[1, i]
    edge_x.extend([positions[s, 0], positions[d, 0], None])
    edge_y.extend([positions[s, 1], positions[d, 1], None])
    edge_z.extend([positions[s, 2], positions[d, 2], None])

fig = go.Figure()

# Add ALL edges
fig.add_trace(go.Scatter3d(
    x=edge_x, y=edge_y, z=edge_z,
    mode='lines',
    line=dict(color='gray', width=1),
    name='All connections',
    hoverinfo='none',
    showlegend=True
))

# Add all atoms by type
for atom_type, label, color in [(0, 'Gd', 'blue'), (1, 'Fe', 'orange')]:
    mask = atom_types == atom_type
    fig.add_trace(go.Scatter3d(
        x=positions[mask, 0],
        y=positions[mask, 1],
        z=positions[mask, 2],
        mode='markers',
        name=label,
        marker=dict(size=3, color=color)
    ))

fig.update_layout(
    width=1000,
    height=800,
    scene=dict(aspectmode='data'),
    title=f'Full graph structure: {edge_index.shape[1]} edges, {len(positions)} atoms'
)
fig.show()

print(f"Total edges: {edge_index.shape[1]}")
print(f"Total atoms: {len(positions)}")
print(f"Average degree: {edge_index.shape[1] / len(positions):.1f}")

Total edges: 9408
Total atoms: 800
Average degree: 11.8


In [46]:
# Choose which atom to show connections for
center_atom = 100  # Change this index

# Find edges connected to this atom
src, dst = edge_index[0], edge_index[1]
connected_edges = (src == center_atom) | (dst == center_atom)
highlight_edge_index = edge_index[:, connected_edges]

# Create edge traces
edge_x, edge_y, edge_z = [], [], []
for i in range(highlight_edge_index.shape[1]):
    s, d = highlight_edge_index[0, i], highlight_edge_index[1, i]
    edge_x.extend([positions[s, 0], positions[d, 0], None])
    edge_y.extend([positions[s, 1], positions[d, 1], None])
    edge_z.extend([positions[s, 2], positions[d, 2], None])

fig = go.Figure()

# Add edges
fig.add_trace(go.Scatter3d(
    x=edge_x, y=edge_y, z=edge_z,
    mode='lines',
    line=dict(color='red', width=2),
    name=f'Connections from atom {center_atom}',
    hoverinfo='none'
))

# Add all atoms by type
for atom_type, label, color in [(0, 'Gd', 'blue'), (1, 'Fe', 'orange')]:
    mask = atom_types == atom_type
    fig.add_trace(go.Scatter3d(
        x=positions[mask, 0],
        y=positions[mask, 1],
        z=positions[mask, 2],
        mode='markers',
        name=label,
        marker=dict(size=3, color=color)
    ))

# Highlight the center atom
fig.add_trace(go.Scatter3d(
    x=[positions[center_atom, 0]],
    y=[positions[center_atom, 1]],
    z=[positions[center_atom, 2]],
    mode='markers',
    name='Center atom',
    marker=dict(size=3, color='red', line=dict(color='black', width=2))
))

fig.update_layout(
    width=1000,
    height=800,
    scene=dict(aspectmode='data')
)
fig.show()

print(f"Atom {center_atom} has {highlight_edge_index.shape[1]} connections")

Atom 100 has 18 connections
