In [44]:
import os
import yaml
import torch
from clean.data import DeepCleanInferenceDataset
from clean.infer import OnlineInference
from clean.model import InferenceModel

In [26]:
os.path.abspath(os.curdir)

'/home/chiajui.chou/deepcleanv2/projects/clean'

In [27]:
# Load the clean-stage configuration from YAML.
# The path can be overridden with the DEEPCLEAN_CLEAN_CONFIG environment
# variable so the notebook is not tied to one user's home directory; the
# default is the original hard-coded location.
clean_config_file = os.environ.get(
    "DEEPCLEAN_CLEAN_CONFIG",
    "/home/chiajui.chou/deepcleanv2/projects/clean/config_clean.yaml",
)
with open(clean_config_file, "r", encoding="utf-8") as file:
    clean_config = yaml.safe_load(file)
print(clean_config.keys())

# The training run directory holds its own config.yaml; its 'data' section
# is inspected here (keys and inference_sampling_rate are printed below).
train_config_file = os.path.join(clean_config['train_dir'], "config.yaml")
with open(train_config_file, "r", encoding="utf-8") as file:
    train_config = yaml.safe_load(file)
print(train_config['data'].keys())
print(train_config['data']['inference_sampling_rate'])

# Wrap the trained model for inference (positional args: train_dir,
# sample_rate, device).
model = InferenceModel(
    clean_config['train_dir'],
    clean_config['sample_rate'],
    clean_config['device'],
)

# Dataset that reads strain (hoft_dir) and witness (witness_dir) data;
# it is given the model as well.
inference_dataset = DeepCleanInferenceDataset(
    hoft_dir=clean_config['hoft_dir'],
    witness_dir=clean_config['witness_dir'],
    model=model,
    device=clean_config['device'],
)

# Driver that combines dataset and model and writes results to outdir.
online_inference = OnlineInference(
    dataset=inference_dataset,
    model=model,
    outdir=clean_config['outdir'],
    device=clean_config['device'],
)


dict_keys(['device', 'sample_rate', 'hoft_dir', 'witness_dir', 'outdir', 'train_dir'])
dict_keys(['fname', 'channels', 'kernel_length', 'freq_low', 'freq_high', 'batch_size', 'train_duration', 'test_duration', 'valid_frac', 'train_stride', 'inference_sampling_rate', 'start_offset', 'filt_order'])
2.0


In [85]:
# online_inference.model.model

In [28]:
# Run the online inference process for a number of iterations (e.g., 300)
# for k in range(300):
#     online_inference.predict_and_write()
#     online_inference.dataset.update()
#     print(f"iteration {k}")

### InferenceDataset

In [29]:
# Initialize DeepCleanInferenceDataset
# Build a standalone DeepCleanInferenceDataset to inspect its windowing
# parameters. Note: `online_inference` from the setup cell still holds the
# instance it was constructed with, not this one.
inference_dataset = DeepCleanInferenceDataset(
    hoft_dir=clean_config['hoft_dir'],
    witness_dir=clean_config['witness_dir'],
    model=model,
    device=clean_config['device'],
)

print(inference_dataset.t0)
print(inference_dataset.kernel_size)

# Report kernel size and stride of both window views
# (X presumably witnesses, y presumably strain — confirm in clean.data).
for label, windows in (
    ("X_inference:", inference_dataset.X_inference),
    ("y_inference:", inference_dataset.y_inference),
):
    print(label)
    print(windows.kernel_size, windows.stride, sep="\n")

# One line per batch yielded by the X view.
for i in inference_dataset.X_inference:
    print(i.shape)

1250916847
4096
X_inference:
4096
2048
y_inference:
4096
2048
torch.Size([5, 21, 4096])


In [42]:
# Peek at each aligned (X, y) batch pair: shapes plus the first row of the
# first batch element (X is 3-D, y is 2-D per the shapes printed below).
print(inference_dataset.t0)
X_dset = inference_dataset.X_inference
y_dset = inference_dataset.y_inference
for X, y in zip(X_dset, y_dset):
    print(X.shape, X[0, 0, :], y.shape, y[0, :], sep="\n")

1250916847
torch.Size([5, 21, 4096])
tensor([0.6945, 0.7658, 0.7648,  ..., 0.3913, 0.3912, 0.3911], device='cuda:0')
torch.Size([5, 4096])
tensor([-1.8676e-19, -2.1235e-19, -1.8091e-19,  ...,  1.2468e-19,
         1.4771e-19,  1.4411e-19], device='cuda:0')


### Model

In [87]:
# Load the exported TorchScript model from the training directory.
device = clean_config['device']
train_dir = clean_config['train_dir']
model_path = os.path.join(train_dir, "model.pt")
# map_location deserializes tensors directly onto the target device, so loading
# does not fail (or briefly duplicate weights) when the saved device differs.
t_model = torch.jit.load(model_path, map_location=device)
# Inference mode for any train/eval-dependent layers (e.g. BatchNorm/Dropout,
# if the architecture has them); a no-op if the model was exported in eval mode.
t_model.eval()
print(type(t_model))

<class 'torch.jit._script.RecursiveScriptModule'>


### Prediction

In [88]:
# Run the raw TorchScript model on the first witness batch.
witness = next(iter(X_dset))
print(witness.device)
print(witness.shape)

# No autograd graph is needed for prediction; disabling it avoids holding
# activations for backward and saves GPU memory.
with torch.no_grad():
    pred = t_model(witness)
print(pred.shape)

cuda:0
torch.Size([5, 21, 4096])
torch.Size([5, 4096])


In [92]:
# Metric padding settings and the inference sampling rate used at train time.
print(train_config['model']['metric']['init_args'])
inference_sampling_rate = train_config['data']['inference_sampling_rate']
# Backward-compatible alias for the previously misspelled name.
inrerence_sampling_rate = inference_sampling_rate
print(inference_sampling_rate)

{'edge_pad': 0.2, 'filter_pad': 0.8}
2.0
