In [13]:
pip install xarray-datatree psycopg2-binary zarr

Collecting zarr
  Downloading zarr-2.14.2-py3-none-any.whl (203 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m203.3/203.3 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Collecting asciitree (from zarr)
  Downloading asciitree-0.3.3.tar.gz (4.0 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting fasteners (from zarr)
  Downloading fasteners-0.18-py3-none-any.whl (18 kB)
Collecting numcodecs>=0.10.0 (from zarr)
  Downloading numcodecs-0.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Building wheels for collected packages: asciitree
  Building wheel for asciitree (setup.py) ... [?25ldone
[?25h  Created wheel for asciitree: filename=asciitree-0.3.3-py3-none-any.whl size=5034 sha256=1ace7ace32e7068f0438bf1d651e26db665f2bbd1266f42a50754874c48d5116
  Stored in directory: /home/jovyan/.cache/pip/wheels

In [2]:
import pandas as pd
from sqlalchemy.orm import sessionmaker
from src.db_utils import connect
from datatree import open_datatree
import matplotlib.pyplot as plt
import yaml
from pathlib import Path

# Load connection settings (the 'host' and 'port' keys are read below) from
# the local YAML config file. safe_load only builds plain YAML types, so the
# config file cannot instantiate arbitrary Python objects the way
# yaml.load(..., yaml.Loader) can.
with Path('config.yml').open() as handle:
    config = yaml.safe_load(handle)

# NOTE(review): the user name and password are hard-coded in the URI below;
# consider moving them to config.yml or environment variables.
URI = f"postgresql://root:root@{config['host']}:{config['port']}/mast_db"

# `connect` is a project helper; it returns a (metadata, engine) pair.
# `metadata.tables[...]` is used by later cells to look tables up by name.
metadata, engine = connect(URI)

# One ORM session, bound to the engine, shared by every query in the notebook.
Session = sessionmaker(bind=engine)
session = Session()

### Database Stats

In [3]:
# Count the rows in the `shots` and `signals` tables, in that order.
num_shots, num_signals = (
    session.query(metadata.tables[table_name]).count()
    for table_name in ('shots', 'signals')
)

# Report both totals.
print(f'Number of shots: {num_shots}')
print(f'Number of signals: {num_signals}')

Number of shots: 25556
Number of signals: 740


### Querying the Metadatabase

Query the shots table and show the results

In [3]:
# Select every row of the `shots` table (no filters applied).
query = session.query(metadata.tables['shots'])

# Run the statement on a short-lived connection. The `with` block closes the
# connection afterwards; the previous inline `engine.connect()` call was
# never closed, leaking one connection per execution of this cell.
with engine.connect() as connection:
    result = pd.read_sql(query.statement, con=connection)
result

Unnamed: 0,shot_id,timestamp,reference_shot,scenario,current_range,heating,divertor_config,pellets,plasma_shape,rmp_coil,...,cpf_vol_ipmax,cpf_vol_max,cpf_vol_truby,cpf_wmhd_ipmax,cpf_wmhd_max,cpf_wmhd_truby,cpf_zeff_ipmax,cpf_zeff_max,cpf_zeff_truby,cpf_zmag_efit
0,11695,2004-12-13 11:54:00+00:00,,,,,Conventional,False,,,...,,,,,,,,,,
1,11696,2004-12-13 12:07:00+00:00,,,,,Conventional,False,,,...,,,,,,,,,,
2,11697,2004-12-13 12:19:00+00:00,,,,,Conventional,False,,,...,,,,,,,,,,
3,11698,2004-12-13 12:31:00+00:00,,,,,Conventional,False,,,...,,,,,,,,,,
4,11699,2004-12-13 12:45:00+00:00,,,,,Conventional,False,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25551,47202,2023-03-31 14:59:33+00:00,45847.0,,,,,False,,False,...,,,,,,,,,,
25552,47203,2023-03-31 15:19:43+00:00,,,,,,False,,False,...,,,,,,,,,,
25553,47204,2023-03-31 15:38:20+00:00,,,,,,False,,False,...,,,,,,,,,,
25554,47205,2023-03-31 15:52:21+00:00,44539.0,,,,,False,,False,...,,,,,,,,,,


In [9]:
scenarios = metadata.tables['scenarios']

# Select every row of the `scenarios` table.
scenario_query = session.query(scenarios)

# Read the rows into a DataFrame on a connection that is closed afterwards
# (the previous inline `engine.connect()` was never closed).
with engine.connect() as connection:
    qscenarios = pd.read_sql(scenario_query.statement, con=connection)
qscenarios

Unnamed: 0,id,name
0,1,S1
1,2,S8
2,3,S6
3,4,S5
4,5,S4
5,6,S3
6,7,S2
7,8,S7
8,9,DN-450-CD-OH
9,10,DN-750-CD-1BSW


Filtering with CPF Summary Data and shot IDs

In [10]:
shots = metadata.tables['shots']

# Keep only shots whose `scenario` column equals 8.
# NOTE(review): presumably this value is an `id` from the scenarios table
# (where id 8 is listed with the name 'S7') - confirm.
scenario_shot_query = (
    session.query(shots)
          .filter(shots.c.scenario == 8)
          # Optionally restrict by shot ID as well:
          # .filter(shots.c.shot_id <= 30131)
)

# Read the rows into a DataFrame on a connection that is closed afterwards
# (the previous inline `engine.connect()` was never closed).
with engine.connect() as connection:
    qshots = pd.read_sql(scenario_shot_query.statement, con=connection)
qshots

Unnamed: 0,shot_id,timestamp,reference_shot,scenario,current_range,heating,divertor_config,pellets,plasma_shape,rmp_coil,...,cpf_vol_ipmax,cpf_vol_max,cpf_vol_truby,cpf_wmhd_ipmax,cpf_wmhd_max,cpf_wmhd_truby,cpf_zeff_ipmax,cpf_zeff_max,cpf_zeff_truby,cpf_zmag_efit
0,27396,2011-11-09 10:24:00+00:00,26531.0,8,700 kA,SS Beam,Conventional,False,,,...,,,,,,,,,,
1,27398,2011-11-09 10:55:00+00:00,27251.0,8,700 kA,SS Beam,Conventional,False,Connected Double Null,False,...,,,,,,,,,,
2,27662,2011-12-01 13:07:00+00:00,27403.0,8,700 kA,"2 Beams,SS Beam,SW Beam",Conventional,False,Connected Double Null,False,...,,,,,,,,,,
3,27663,2011-12-01 13:23:00+00:00,27403.0,8,700 kA,"2 Beams,SS Beam,SW Beam",Conventional,False,Connected Double Null,False,...,,,,,,,,,,
4,27845,2011-12-12 13:31:00+00:00,27838.0,8,,SS Beam,Conventional,False,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118,29928,2013-08-20 16:18:00+00:00,29222.0,8,700 kA,SS Beam,Conventional,False,Connected Double Null,False,...,,,,,,,,,,
119,29929,2013-08-20 16:35:00+00:00,29222.0,8,700 kA,SS Beam,Conventional,False,Connected Double Null,False,...,,,,,,,,,,
120,29930,2013-08-20 16:52:00+00:00,29222.0,8,700 kA,Ohmic,Conventional,False,,,...,,,,,,,,,,
121,29931,2013-08-20 17:08:00+00:00,29222.0,8,700 kA,SS Beam,Conventional,False,Connected Double Null,False,...,,,,,,,,,,


A more advanced query. Here we: 
 - Find shots with a given CPF value
 - Find corresponding signals
 - Filter signals by name

In [16]:
shots = metadata.tables['shots']
signals = metadata.tables['signals']
shot_signal_link = metadata.tables['shot_signal_link']

# Query all shots whose ID lies in the inclusive range [27396, 30036]
shot_range_query = (
    session.query(shots)
          .filter(shots.c.shot_id >= 27396)
          .filter(shots.c.shot_id <= 30036)
)

# All three reads share one connection, which the `with` block closes on
# exit. Previously each read opened its own `engine.connect()` and never
# closed it.
with engine.connect() as connection:
    qshots = pd.read_sql(shot_range_query.statement, con=connection)

    # Shot IDs as strings; later cells compare them against node names
    # when filtering the loaded data trees.
    shot_ids = [str(shot_id) for shot_id in qshots['shot_id']]

    # Query for the distinct signal IDs linked to any of those shots
    shot_signal_query = (
        session.query(shot_signal_link.c.signal_id)
        .filter(shot_signal_link.c.shot_id.in_(qshots['shot_id'].tolist()))
        .distinct()
    )
    qshot_signal = pd.read_sql(shot_signal_query.statement, con=connection)

    # Query for signal metadata, keeping only names containing 'AMC'
    qsignal = (
        session.query(signals)
        .filter(signals.c.signal_id.in_(qshot_signal['signal_id'].tolist()))
        .filter(signals.c.name.contains('AMC'))
    )
    result = pd.read_sql(qsignal.statement, con=connection)

result

Unnamed: 0,signal_id,name,units,rank,dim_1_label,dim_2_label,dim_3_label,uri,description,signal_type,quality,doi,camera_metadata,camera
0,405,AMC_EFPS CURRENT,kA,1,,,,/home/ir-jack5/data/AMC_EFPS CURRENT.zarr,EFPS Current,Analysed,Not Checked,,,
1,406,AMC_ERROR FIELD_02,kA.turn,1,,,,/home/ir-jack5/data/AMC_ERROR FIELD_02.zarr,Error Field/02,Analysed,Not Checked,,,
2,407,AMC_ERROR FIELD_05,kA.turn,1,,,,/home/ir-jack5/data/AMC_ERROR FIELD_05.zarr,Error Field/05,Analysed,Not Checked,,,
3,408,AMC_P2IL COIL CURRENT,kA.turn,1,,,,/home/ir-jack5/data/AMC_P2IL COIL CURRENT.zarr,P2IL Coil Current,Analysed,Not Checked,,,
4,409,AMC_P2IL FEED CURRENT,kA,1,,,,/home/ir-jack5/data/AMC_P2IL FEED CURRENT.zarr,P2il Feed Current,Analysed,Not Checked,,,
5,410,AMC_P2IU COIL CURRENT,kA.turn,1,,,,/home/ir-jack5/data/AMC_P2IU COIL CURRENT.zarr,P2IU Coil Current,Analysed,Not Checked,,,
6,411,AMC_P2IU FEED CURRENT,kA,1,,,,/home/ir-jack5/data/AMC_P2IU FEED CURRENT.zarr,P2iu Feed Current,Analysed,Not Checked,,,
7,412,AMC_P2L CASE CURRENT,kA,1,,,,/home/ir-jack5/data/AMC_P2L CASE CURRENT.zarr,P2L Case Current,Analysed,Not Checked,,,
8,413,AMC_P2L CURRENT,kA.turn,1,,,,/home/ir-jack5/data/AMC_P2L CURRENT.zarr,P2L Current,Analysed,Not Checked,,,
9,414,AMC_P2OL COIL CURRENT,kA.turn,1,,,,/home/ir-jack5/data/AMC_P2OL COIL CURRENT.zarr,P2OL Coil Current,Analysed,Not Checked,,,


### Loading data

Here is an example of loading the data found in the database into a dataset and plotting some time series

In [27]:
pip install paramiko

Collecting paramiko
  Downloading paramiko-3.2.0-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.2/224.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting bcrypt>=3.2 (from paramiko)
  Downloading bcrypt-4.0.1-cp36-abi3-manylinux_2_28_x86_64.whl (593 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m593.7/593.7 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
Collecting pynacl>=1.5 (from paramiko)
  Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m856.7/856.7 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
Installing collected packages: bcrypt, pynacl, paramiko
Successfully installed bcrypt-4.0.1 paramiko-3.2.0 pynacl-1.5.0
Note: you may need to restart the kernel to use updated packages.


In [29]:
import zarr

# Storage location of the first signal in the current `result` frame.
uri = result['uri'].iloc[0]
uri

# Try to open a Zarr store through an ssh:// URL.
# The recorded output of this cell is
# "SSHException: No authentication methods available".
ssh_store_url = 'ssh://pearl008@ui:/tmp/test.zarr'
zarr.storage.FSStore(ssh_store_url)

SSHException: No authentication methods available

In [14]:
# Read data: open one DataTree per signal in `result`, keyed by signal name.
data = {
    name: open_datatree(uri, engine='zarr')
    for name, uri in zip(result['name'], result['uri'])
}

# NOTE(review): the signal names listed by the query use a space after the
# prefix (e.g. 'AMC_EFPS CURRENT'); confirm this key matches a stored name.
dataset = data['AMC_PLASMA_CURRENT']
# Choose only relevant shots
dataset = dataset.filter(lambda x: x.name in shot_ids)
# Keep only samples whose time coordinate lies between 0 and 0.5
dataset = dataset.sel(time=slice(0, .5))

fig, ax = plt.subplots()

# Plot one line per shot. `shot` starts as None so the y-label below is
# skipped, instead of raising NameError or reusing a stale value, when the
# filtered tree is empty.
shot = None
for shot_id, shot in dataset.items():
    ax.plot(shot['time'], shot['data'], label=f'Shot {shot_id}')

if shot is not None:
    # Label and units are taken from the last shot plotted.
    ax.set_ylabel(f'{shot.label} ({shot.units})')
ax.set_xlabel('Time')
ax.legend();

GroupNotFoundError: group not found at path ''

Another example with the same data, but with multi-dimensional data this time. Use EFM PSI, which should be an equilibrium reconstruction.

In [12]:
# Query for signal metadata, keeping only names containing 'EFM_PSI(R,Z)'
qsignal = (
    session.query(signals)
    .filter(signals.c.name.contains('EFM_PSI(R,Z)'))
)

# Read the rows into a DataFrame on a connection that is closed afterwards
# (the previous inline `engine.connect()` was never closed).
with engine.connect() as connection:
    result = pd.read_sql(qsignal.statement, con=connection)
result

Unnamed: 0,signal_id,name,units,rank,dim_1_label,dim_2_label,dim_3_label,uri,description,signal_type,quality,doi,camera_metadata,camera
0,882,"EFM_PSI(R,Z)",Wb/rad,3,,,,/home/lhs18285/git/fair-mast/data/mast/zarr/EF...,"psi(r,z)",Analysed,Not Checked,,,


In [8]:
# Read data: open one DataTree per signal in `result`, keyed by signal name.
data = {
    name: open_datatree(uri, engine='zarr')
    for name, uri in zip(result['name'], result['uri'])
}

# Choose only relevant shots
dataset = data['EFM_PSI(R,Z)']
dataset = dataset.filter(lambda x: x.name in shot_ids)

# Take a single time slice from every shot.
# NOTE(review): the recorded output of this cell is "IndexError: index 50 is
# out of bounds for axis 0 with size 16" - presumably at least one shot has
# fewer than 51 time samples; confirm and choose the index accordingly.
TIME_INDEX = 50
dataset = dataset.isel(time=TIME_INDEX)

shot_items = list(dataset.items())
n_shots = len(shot_items)

# Size the grid from the number of shots instead of a fixed 2x3 layout, which
# raised IndexError on `axes[index]` once more than six shots were selected.
N_COLS = 3
n_rows = max(1, -(-n_shots // N_COLS))  # ceiling division without math.ceil
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(10, 2.5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, (shot_id, shot) in zip(axes, shot_items):
    ax.matshow(shot['data'], cmap='plasma')
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    ax.set_title(f'Shot {shot_id}')

# Hide any panels left over in the last row.
for ax in axes[n_shots:]:
    ax.axis('off')

# Units come from the last shot plotted; omit them when nothing was plotted
# instead of failing on an undefined (or stale) `shot`.
if shot_items:
    last_shot = shot_items[-1][1]
    fig.suptitle(f'EFM_PSI(R,Z) ({last_shot.attrs["units"]})')
else:
    fig.suptitle('EFM_PSI(R,Z)')
fig.tight_layout()

IndexError: index 50 is out of bounds for axis 0 with size 16