In [None]:
import os
import pathlib
from pathlib import Path

# If the kernel was started inside a `notebooks` directory, step up one level
# so that the relative paths used in later cells resolve from the parent.
if Path.cwd().match("notebooks"):
    os.chdir("..")
# Display the effective working directory as this cell's output.
os.getcwd()

In [None]:
import sys
from typing import Tuple, Optional

from typeguard import typechecked

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import scipy.signal

import torch

import gw_data
from preprocessor_meta import raw_meta
from models.cnn1d import CqtInputLayer

In [None]:
# Root of the G2Net sample data. The default is relative to the working
# directory established by the first cell; set G2NET_DATA_DIR to point the
# notebook at a copy stored elsewhere without editing the notebook.
DATA_DIR = Path(os.environ.get('G2NET_DATA_DIR', 'g2net-data-000'))

In [None]:
# Figure defaults tuned for a 2020-era MacBook Pro display: a higher DPI and a
# compact three-step font scale.
plt.rcParams['figure.dpi'] = 140

SMALL_FONT_SIZE = 6
MEDIUM_FONT_SIZE = 8
BIGGER_FONT_SIZE = 10

plt.rcParams.update({
    'font.size': SMALL_FONT_SIZE,          # default text size
    'axes.titlesize': SMALL_FONT_SIZE,     # axes titles
    'axes.labelsize': MEDIUM_FONT_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_FONT_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_FONT_SIZE,    # y tick labels
    'legend.fontsize': SMALL_FONT_SIZE,    # legend text
    'figure.titlesize': BIGGER_FONT_SIZE,  # figure suptitle
})

In [None]:
# Line colour for each channel index; index i is drawn in SIGNAL_COLORS[i].
SIGNAL_COLORS = ['red', 'green', 'blue']

def plot_filter_line(ax, sigs, idx, left: int = 0, right: Optional[int] = None,
                     times: Optional[np.ndarray] = None):
    """Plot channel ``idx`` of ``sigs`` on ``ax`` over the sample window ``[left:right]``.

    Args:
        ax: matplotlib Axes to draw on.
        sigs: indexable collection of 1-D signals; ``sigs[idx]`` is plotted.
        idx: channel index; also selects the line colour from SIGNAL_COLORS.
        left: first sample of the window (inclusive).
        right: end of the window (exclusive, slice semantics). ``None`` means
            "to the end of the signal"; an explicit ``0`` is honoured as a bound.
        times: x-axis values aligned sample-for-sample with ``sigs[idx]``.
            When omitted, falls back to the notebook-global ``FILTER_TIMES``,
            which must already be defined when this function is called.
    """
    sig = sigs[idx]
    # `right or len(...)` would also replace an explicit right=0; only None
    # means "no upper bound".
    if right is None:
        right = len(sig)
    if times is None:
        # NOTE(review): hidden-state dependency; FILTER_TIMES is not defined in
        # the cells shown here. Prefer passing `times` explicitly.
        times = FILTER_TIMES

    ax.minorticks_on()
    ax.grid(which='major', color='#555555', linestyle='-', linewidth=0.7)
    ax.grid(which='minor', color='#AAAAAA', linestyle=':', linewidth=0.5)
    ax.set_axisbelow(False)  # draw grid lines above the plotted line

    ax.plot(times[left:right],
            sig[left:right],
            SIGNAL_COLORS[idx])

@typechecked
def plot_filter_sigs(_id: str, sigs: np.ndarray, left: int = 0, right: Optional[int] = None):
    """Draw the first three channels of ``sigs`` as stacked panels sharing an x axis.

    Each panel is rendered by ``plot_filter_line`` over the sample window
    ``[left:right]``, and the figure is titled with the sample id.
    """
    fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=[6, 5])
    for channel, panel in enumerate(axes):
        plot_filter_line(panel, sigs, channel, left, right)
    fig.suptitle(f'id={_id}')

In [None]:
# CQT front end from the model code, run on CPU in single precision; used
# below only to visualise raw signals as spectrograms.
dtype = torch.float
cqt_layer = CqtInputLayer(raw_meta.output_shape)
cqt_layer.to("cpu", dtype=dtype)

@typechecked
def cqt_sig(sig: np.ndarray) -> np.ndarray:
    """Run ``sig`` through ``cqt_layer`` and return the first element of the result as numpy.

    The forward pass runs under ``torch.no_grad()``. If the layer has
    trainable parameters, its output would otherwise require grad, and
    ``.numpy()`` would raise. No autograd graph is needed for plotting.
    """
    with torch.no_grad():
        # Call the module rather than `.forward` so any registered hooks run.
        out = cqt_layer(torch.tensor(sig, dtype=dtype))
    # NOTE(review): [0] presumably drops a leading batch/channel axis added by
    # the layer; confirm against CqtInputLayer.
    return out.numpy()[0]

In [None]:
@typechecked
def plot_sig_q(sig: np.ndarray, title: Optional[str] = None):
    """Show the constant-Q spectrogram of one signal channel.

    The colour scale spans the spectrogram's own minimum..maximum, so each plot
    is contrast-normalised independently. Colours are not comparable in
    absolute terms between plots.

    Args:
        sig: a single channel; its shape must equal ``raw_meta.output_shape[1:]``.
        title: optional axes title (e.g. sample id and channel index).

    Raises:
        ValueError: if ``sig`` does not have the expected shape.
    """
    expected_shape = raw_meta.output_shape[1:]
    if sig.shape != expected_shape:
        raise ValueError(f"expected shape {expected_shape}; got {sig.shape}")
    spec = cqt_sig(sig)

    # Let the colour norm do the min/max rescaling. The manual
    # (spec - min) / max(...) form divides 0 by 0 for a constant spectrogram
    # (e.g. an all-zero input) and renders NaNs. Normalize instead maps a
    # zero-width range to the bottom of the colormap.
    norm = Normalize(vmin=float(np.min(spec)), vmax=float(np.max(spec)))

    _, ax = plt.subplots(nrows=1, ncols=1, figsize=[6, 2])
    ax.pcolormesh(cqt_layer.times, cqt_layer.freqs, spec,
                  norm=norm, cmap='viridis', shading="nearest")
    if title is not None:
        ax.set_title(title)
    ax.minorticks_on()
    ax.grid(which='major', color='#DDDDDD', linestyle='-', linewidth=0.7)
    ax.grid(which='minor', color='#CCCCCC', linestyle=':', linewidth=0.5)
    ax.set_axisbelow(False)  # draw grid lines on top of the mesh
    plt.show()

In [None]:
# A single training sample to inspect. Later cells index it by channel, so the
# array is presumably (N_SIGNALS, n_samples); TODO confirm.
test_id = '000a218fdd'
test_path = gw_data.train_file(DATA_DIR, test_id)
test_sigs = np.load(test_path)

In [None]:
# One spectrogram per channel of the sample loaded above.
for channel_idx in range(gw_data.N_SIGNALS):
    plot_sig_q(test_sigs[channel_idx])