In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
from pathlib import Path
import xarray as xr
import tensorflow as tf
import datetime
from shared.models import *
from shared.generators import SAT1DataGenerator, NewSAT1DataGenerator
from shared.utilities import *
from sklearn.metrics import classification_report
import random
random.seed(42)
%env TF_FORCE_GPU_ALLOW_GROWTH=true
%env TF_GPU_ALLOCATOR=cuda_malloc_async

2023-08-28 13:07:27.035724: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TF_GPU_ALLOCATOR=cuda_malloc_async


### Set up data

In [2]:
data_path = Path("data/sat1/split_stage_data.nc")

data = xr.load_dataset(data_path)

In [3]:
participants = data.participant.values.tolist()
test_participants = random.sample(participants, 5)
train_participants = [p for p in participants if p not in test_participants]

In [4]:
test_data = data.sel(participant=test_participants)
train_data = data.sel(participant=train_participants)

# train_data = data.sel(
#     participant=[
#         "0021",
#         "0022",
#         "0023",
#         "0024",
#     ]
# )

# test_data = data.sel(
#     participant=[
#         "0025"
#     ]
# )

In [5]:
batch_size = 16
train_gen = NewSAT1DataGenerator(train_data, batch_size=batch_size)
test_gen = NewSAT1DataGenerator(test_data, batch_size=batch_size)

# tf.config.optimizer.set_experimental_options({"layout_optimizer": False})

In [6]:
model = SAT1Base(30, 157, 4)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 30, 157, 1)]      0         
                                                                 
 conv2d (Conv2D)             (None, 30, 153, 64)       384       
                                                                 
 max_pooling2d (MaxPooling2  (None, 30, 76, 64)        0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 30, 74, 128)       24704     
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 30, 37, 128)       0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 30, 35, 256)       98560 

2023-08-28 13:07:32.578206: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] could not open file to read NUMA node: /sys/bus/pci/devices/0000:07:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-08-28 13:07:32.607275: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] could not open file to read NUMA node: /sys/bus/pci/devices/0000:07:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-08-28 13:07:32.607376: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] could not open file to read NUMA node: /sys/bus/pci/devices/0000:07:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-08-28 13:07:32.609674: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] could not open file to read NUMA node: /sys/bus/pci/devices/0000:07:00.0/numa_node
Your kernel may have been built without NUMA support.
2023-08-28 13:07:32.609766: I tensorflow/compile

In [7]:
run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
path = Path("logs/") / run_id
to_write = {"Model summary": get_summary_str(model), "Test": "test:)"}
fit = model.fit(
    train_gen,
    epochs=5,
    callbacks=[earlyStopping_cb, LoggingTensorBoard(to_write, log_dir=path)],
    # TODO: Create val_gen
    validation_data=test_gen,
    use_multiprocessing=True,
    workers=8,
)

Epoch 1/5


2023-08-28 13:07:39.621446: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
2023-08-28 13:07:41.096259: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-08-28 13:07:41.104468: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5572e46f1230 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-08-28 13:07:41.104495: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2023-08-28 13:07:41.108113: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-08-28 13:07:41.214699: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the p

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [42]:
# Run before re-training to clear up VRAM
import gc

gc.collect()
tf.keras.backend.clear_session()
del model

In [30]:
test_gen.cat_labels

['confirmation', 'decision', 'encoding', 'response']

In [41]:
print("Testset results")
predicted_classes = np.argmax(model.predict(test_gen), axis=1)
predicted_classes = [test_gen.cat_labels[idx] for idx in list(predicted_classes)]
print(classification_report(test_gen.full_labels, predicted_classes))
# print(test_gen.categories)

Testset results
              precision    recall  f1-score   support

confirmation       0.73      0.63      0.68       441
    decision       0.72      0.81      0.76       844
    encoding       0.88      0.89      0.89       869
    response       0.84      0.80      0.82       870

    accuracy                           0.80      3024
   macro avg       0.80      0.78      0.79      3024
weighted avg       0.81      0.80      0.80      3024

