# Reconstruction Error on the validation and test sets

In [None]:
import os
import torch
from src.util.device import set_device
from torch import nn
from src.data.filesampler import sample_filepaths
from src.util.consts import TEST_TASK_1
from src.util.signals import load_chunks_pair_list

# Directory prefix for the per-model "<name>/<name>_generator.pt" checkpoints
# and for the output file reconstruction.txt; "" resolves everything relative
# to the current working directory (os.path.join("", x) == x).
PATH: str = ""

In [None]:
# Expose only GPU index 0 to this process (the variable only takes effect if
# set before CUDA is initialised), then let the project helper pick the torch
# device -- presumably the one the models/tensors are moved to; see
# src.util.device.set_device.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
device = set_device()

In [None]:
# Test split: file paths for TEST_TASK_1 (sample_rate=1 presumably keeps every
# file -- confirm in src.data.filesampler).
test_paths = sample_filepaths(TEST_TASK_1, sample_rate=1)

# Validation split: one comma-separated entry per line, presumably
# "clean_path,recorded_path" -- verify against load_chunks_pair_list.
with open("../val_paths.txt", "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing empty line) so they don't become bogus
    # [''] entries that the chunk loader cannot handle.
    val_paths = [line.split(",") for line in f.read().splitlines() if line.strip()]

val_chunks = load_chunks_pair_list(val_paths)
test_chunks = load_chunks_pair_list(sampled_paths=test_paths)

In [None]:
loss = nn.MSELoss()


def test(chunks, model, batch_size=128, target_device=None):
    """Return the per-file-averaged MSE reconstruction loss of *model*.

    All recorded chunks are concatenated and passed through *model* in
    mini-batches (under ``torch.no_grad`` and, for ``nn.Module`` models, in
    eval mode -- the previous train/eval mode is restored afterwards). The
    outputs are split back per file, the MSE against that file's clean chunks
    is computed, and the per-file losses are averaged, so every file weighs
    the same regardless of how many chunks it has.

    Args:
        chunks: Non-empty sequence of ``(clean, recorded)`` tensor pairs, one
            per file. Recorded tensors are concatenated along dim 0, so they
            must agree in all other dimensions.
        model: Callable (normally an ``nn.Module``) mapping a batch of
            recorded chunks to reconstructions of the same batch size.
        batch_size: Number of chunks per forward pass.
        target_device: Device the batches are moved to before the forward
            pass; defaults to the notebook-level ``device``.

    Returns:
        0-dim CPU tensor holding the mean per-file loss.

    Raises:
        ValueError: if *chunks* is empty or *batch_size* is not positive.
    """
    if not chunks:
        raise ValueError("test() needs at least one (clean, recorded) chunk pair")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    run_device = device if target_device is None else target_device
    recorded = torch.cat([rec for _, rec in chunks], dim=0)

    is_module = isinstance(model, nn.Module)
    was_training = model.training if is_module else None
    if is_module:
        model.eval()  # disable dropout / use running batch-norm stats
    try:
        with torch.no_grad():  # evaluation only: don't build autograd graphs
            # range(0, n, batch_size) never yields an empty trailing batch,
            # unlike range(n // batch_size + 1) when n is a multiple.
            outputs = torch.cat(
                [
                    model(recorded[start:start + batch_size].to(run_device)).to("cpu")
                    for start in range(0, recorded.shape[0], batch_size)
                ],
                dim=0,
            )
    finally:
        if is_module:
            model.train(was_training)

    # Cut the flat output back into one slice per file, in input order.
    per_file = torch.split(outputs, [rec.shape[0] for _, rec in chunks], dim=0)
    file_losses = [
        loss(result, clean) for result, (clean, _) in zip(per_file, chunks)
    ]
    return torch.stack(file_losses).mean()


# The helper's name starts with "test"; stop pytest from collecting it as a
# test case if this code is ever imported into a test module.
test.__test__ = False

In [None]:
from src.segan import SEGAN


# Evaluate every model directory under PATH. Only entries that actually hold
# "<name>/<name>_generator.pt" are kept, so unrelated folders (e.g. Jupyter's
# ".ipynb_checkpoints") no longer make segan.load() fail. Sorting makes the
# order of `dirs` -- and therefore of the error lists -- reproducible.
# os.listdir("") raises, hence the "." fallback; os.path.join("", x) == x.
dirs = sorted(
    name
    for name in os.listdir(PATH or ".")
    if os.path.isfile(os.path.join(PATH, name, f"{name}_generator.pt"))
)

val_errors = []
test_errors = []

segan = SEGAN()
for model_dir in dirs:
    # Load this directory's generator checkpoint into the single SEGAN
    # instance (presumably replacing the previously loaded weights).
    generator_path = os.path.join(PATH, model_dir, f"{model_dir}_generator.pt")
    segan.load(generator_path)

    val_errors.append(test(val_chunks, segan.generator))
    test_errors.append(test(test_chunks, segan.generator))

    print(f"Calculated error for {model_dir}!")
    

In [None]:
# Write the collected errors to PATH/reconstruction.txt:
#   line 1: validation errors, one per model, in the order of `dirs`
#   line 2: test errors, same order
#   line 3: the model directory names, so each value can be matched to its model
output_path = os.path.join(PATH, "reconstruction.txt")
with open(output_path, "w", encoding="utf-8") as f:
    # float() turns the torch scalars returned by test() into plain numbers,
    # giving "[0.0123, ...]" instead of "[tensor(0.0123), ...]".
    f.write(str([float(err) for err in val_errors]))
    f.write("\n")
    f.write(str([float(err) for err in test_errors]))
    f.write("\n")
    f.write(str(list(dirs)))