# Apply a U-Net with the size of the generator to solve the task directly

In [None]:
from typing import Dict, List, Optional, Union

import torch

from src.data.dataloader import get_loader
from src.data.dataset import SpeechData
from src.data.files import write_model
from src.data.filesampler import sample_filepaths
from src.eval.testing import Tester
from src.networks.unet_1D import UNet_1D
from src.util.consts import LOG_INTERVAL, TEST_TASK_1
from src.util.logger import Logger
from src.util.signals import load_chunks_pair_list


class Unet:
    """Train a single ``UNet_1D`` with an MSE reconstruction loss.

    The network receives the recorded signal and is optimised to reproduce
    the clean signal directly. This class owns exactly one network
    (``self.network``) and one optimizer (``self.optimizer``); there is no
    discriminator and no learning-rate scheduler.
    """

    def __init__(
        self,
        levels: List[str],
        hyperparameters: Dict[str, Union[int, float]],
        device: torch.device,
        val_paths: Optional[List[str]] = None,
    ) -> None:
        """Build the validation set, training loader, network and optimizer.

        Args:
            levels: Task identifiers forwarded as ``tasks`` to the file
                sampler, the dataset and the logger.
            hyperparameters: Must contain ``"batch_size"`` and ``"lr"``.
            device: Device handed to the loader and the network. When its
                type is ``"cuda"`` the network is wrapped in ``torch.compile``.
            val_paths: Validation file paths. When ``None``, paths are drawn
                with ``sample_filepaths(tasks=levels, sample_rate=0.005)``.
        """
        self.levels = levels
        self.hyperparameters = hyperparameters
        self.device = device
        self.tags = []

        # Initialize dataset
        self.val_paths = (
            sample_filepaths(tasks=levels, sample_rate=0.005)
            if val_paths is None
            else val_paths
        )
        self.val_chunks = load_chunks_pair_list(self.val_paths)

        # The validation paths are passed as ``ignore_paths`` so the training
        # dataset can leave them out.
        dataset = SpeechData(tasks=levels, ignore_paths=self.val_paths)

        self.train_loader = get_loader(
            dataset, batch_size=hyperparameters["batch_size"], device=device
        )

        # Initialize networks
        self.network = UNet_1D(device=device)

        # Compile models if on GPU
        if device.type == "cuda":
            torch.set_float32_matmul_precision("high")
            self.network = torch.compile(self.network)
            print("Compiled network!")

        # Initialize optimizers
        self.optimizer = torch.optim.AdamW(
            self.network.parameters(), lr=hyperparameters["lr"]
        )

        self.reconstruction_loss = torch.nn.MSELoss()

    def _current_lr(self) -> float:
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]["lr"]

    def learn(
        self,
        num_episodes: int = 2000,
        sweep: bool = False,
        disable_testing: bool = False,
    ) -> None:
        """Run the training loop, with periodic evaluation and checkpoints.

        Creates ``self.logger`` and ``self.tester``, then iterates over the
        training loader ``num_episodes`` times. The batch loss is logged every
        ``LOG_INTERVAL`` iterations. Whenever ``episode % 80 == 0`` the tester
        is run, its results are logged and :meth:`write` is called.

        Args:
            num_episodes: Number of passes over ``self.train_loader``.
            sweep: Forwarded to ``Logger``.
            disable_testing: Forwarded to ``Tester``.
        """
        # Initialize logger
        # NOTE(review): the "segan" tag is kept as-is; confirm whether runs of
        # this U-Net baseline are meant to carry it.
        self.logger = Logger(
            tasks=self.levels,
            hyperparameters=self.hyperparameters,
            tags=["segan"] + self.tags,
            sweep=sweep,
            val_paths=self.val_paths,
        )
        self.tester = Tester(
            run_name=self.logger.run_name,
            paths=self.val_paths,
            chunks=self.val_chunks,
            device=self.device,
            disable_testing=disable_testing,
            write_all=True,
        )

        for episode in range(num_episodes):
            # Keeps the episode-level log below well-defined even when the
            # loader yields no batches.
            i = 0
            for i, (clean_files, recorded_files, _) in enumerate(self.train_loader):
                clean_files = clean_files.to(self.device, non_blocking=True)
                recorded_files = recorded_files.to(self.device, non_blocking=True)

                # Update the network on the reconstruction loss
                self.network.zero_grad(set_to_none=True)
                g_out = self.network(recorded_files)
                generator_loss = self.reconstruction_loss(g_out, clean_files)

                generator_loss.backward()
                self.optimizer.step()

                if i % LOG_INTERVAL == 0:
                    self.logger.log_metrics(
                        generator_loss=generator_loss.item(),
                        episode=episode,
                        iteration=i,
                        # Report the optimizer's real learning rate rather
                        # than a hard-coded constant.
                        lr=self._current_lr(),
                    )

            if episode % 80 == 0:
                (
                    chunk_recon_loss,
                    mean_cer,
                    cers,
                    sample_paths,
                    transcriptions,
                ) = self.tester.test(
                    generator=self.network,
                    episode=episode,
                )

                self.logger.log_metrics(
                    chunk_recon_loss=chunk_recon_loss,
                    mean_cer=mean_cer,
                    cers=cers,
                    episode=episode,
                    iteration=i,
                    # No scheduler exists in this class, so read the rate
                    # straight from the optimizer.
                    lr=self._current_lr(),
                    audio_paths=sample_paths,
                    transcriptions=transcriptions,
                )

                self.write(episode=episode)

        self.logger.finish()

    def test(self):
        """Evaluate the network on the validation paths and on ``TEST_TASK_1``.

        Relies on ``self.tester``, which is only created by :meth:`learn`.

        Returns:
            A ``(val_results, test_results)`` tuple holding whatever
            ``self.tester.test`` returns for each evaluation.
        """
        # NOTE(review): learn() calls tester.test with only ``generator`` and
        # ``episode``; confirm that it also accepts the extra keywords below.
        # Validation result
        val_results = self.tester.test(
            generator=self.network,
            paths=self.val_paths,
            chunks=self.val_chunks,
            episode="final_validation",
            device=self.device,
            write_all=True,
        )

        # Test results
        self.test_paths = sample_filepaths(TEST_TASK_1, sample_rate=1)
        test_chunks = load_chunks_pair_list(sampled_paths=self.test_paths)
        test_results = self.tester.test(
            generator=self.network,
            paths=self.test_paths,
            chunks=test_chunks,
            episode="final_testing",
            device=self.device,
            write_all=True,
        )

        return val_results, test_results

    def write(self, episode: Optional[int] = None) -> None:
        """Persist the network and its optimizer through ``write_model``.

        Relies on ``self.logger``, which is only created by :meth:`learn`.

        Args:
            episode: When given, the model names are prefixed with
                ``episode_<episode>_``; otherwise no prefix is used.
        """
        name_prefix = f"episode_{episode}_" if episode is not None else ""

        # The original "generator" names are kept for the single network and
        # its optimizer. NOTE(review): confirm downstream loaders expect them.
        write_model(
            model=self.network,
            run_name=self.logger.run_name,
            model_name=f"{name_prefix}generator",
        )
        write_model(
            self.optimizer,
            run_name=self.logger.run_name,
            model_name=f"{name_prefix}generator_optimizer",
        )
