In [None]:
# from utils.compression import compress_directories, decompress_directories

# base_path = "downloaded_data/gwilliams2023"
# destination_path = "data/gwilliams2023"

# decompress_directories(
#     base_path,
#     destination_path,
#     checksum_file_name="checksums.txt",
#     delete_compressed_files=True,
#     num_workers=None
# )

In [None]:
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# Brain encoder configuration for this notebook. The values below switch off the
# merger, conditional layers and quantizer, and set zero transformer
# encoder/decoder layers.
config = SimpleConvConfig(
    # Str to list of possible conditions
    conditions=None,
    # Channels
    in_channels=208,
    out_channels=128,
    hidden_dim=256,
    dropout=0.3,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=False,
    layout_scaling="minmax",
    # Merger with spatial attn
    merger=False,
    merger_emb_dim=0,
    merger_channels=0,
    merger_dropout=0.0,
    merger_conditional=None,
    # Initial
    initial_linear=256,
    initial_depth=1,
    # Conditional layers
    conditional_layers=False,
    conditional_layers_dim=None,  # input or hidden_dim
    # Conv layer overall structure
    depth=4,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.2,
    batch_norm=True,
    # Quantizer
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformers Encoders
    transformer_input=None,
    transformer_encoder_emb=None,
    transformer_encoder_layers=0,
    transformer_encoder_heads=0,
    # Transformer Decoders
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

In [None]:
import gc
import random
import time
from librosa import cache
from regex import E
from tqdm import tqdm
from config.simpleconv_config import SimpleConvConfig
from models.simpleconv import SimpleConv
from studies.study_factory import StudyFactory
import typing as tp
import json
from itertools import product
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import os
import logging
import shutil
import torch

from .dataloader import DataLoader
from .dataloader.audio_batch import AudioBatch
from .losses.clip import CLIPLoss
from config import SimpleConvConfig, Config
from train.training_session import TrainingSession

# Module-level device string. `TrainingSessionV0.train` takes its device as an
# explicit argument, so this value only matters where a caller passes it in.
# NOTE(review): hard-coded to CUDA; presumably intended for a GPU machine — confirm.
device = "cuda"


class TrainingConfigV0(Config):
    """Configuration of a training session that pairs brain recordings with audio.

    Holds the brain encoder configuration, the train/test partition policy,
    the brain and audio pre-processing parameters and the optimisation
    hyperparameters.
    """

    def __init__(
        self,
        brain_encoder_config: SimpleConvConfig,
        data_partition: tp.Dict[str, tp.Dict[str, tp.List[str]]],
        # Pre-processing parameters
        # Brain
        new_freq: int = 100,
        frequency_bands: tp.Optional[tp.Dict[str, tp.Tuple[float, float]]] = None,
        max_random_shift: float = 2.0,
        window_size: int = 4,
        window_stride: int = 1,
        brain_clipping: float = 20,
        baseline_window: float = 0.5,
        notch_filter: bool = True,
        scaling: str = "minmax",
        # Audio
        audio_model: str = "openai/whisper-large-v3",
        audio_sample_rate: int = 16000,
        hop_length: int = 160,
        # Hyperparameters
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-4,
        epochs: int = 50,
        batch_size: int = 128,
        use_clip_loss: bool = True,
        use_mse_loss: bool = True,
        alpha: float = 0.5,
        random_test_size: int = 3,
        seed: int = 42,
    ):
        """Stores every setting as an attribute and validates the loss settings.

        Arguments:
            brain_encoder_config -- Configuration of the brain encoder model.
            data_partition -- key: study_name, value: dict with keys
                    "testing_subjects" and "testing_tasks", each a list of
                    identifiers. Ones not specified in either list are used
                    for training.
            frequency_bands -- Mapping of band name to a (low, high) pair.
                    Defaults to {"all": (0.5, 100)} when omitted. A fresh
                    dict is created per instance so instances never share
                    (and cannot accidentally mutate) one default object.
            alpha -- Weight of the CLIP loss when both losses are enabled;
                    must lie in [0, 1].
            use_clip_loss, use_mse_loss -- Loss selection; at least one must
                    be enabled.

        The remaining arguments are stored unchanged under the same name.

        Raises:
            AssertionError -- if alpha is outside [0, 1] or both losses are
                    disabled.
        """
        self.brain_encoder_config = brain_encoder_config
        # key: study_name, value: dict with keys: "testing_subjects", "testing_tasks",
        # where each value is a list of int. Ones not specified in either lists are
        # used for training.
        self.data_partition = data_partition

        # Pre-processing parameters
        # Brain
        self.new_freq = new_freq
        self.frequency_bands = (
            {"all": (0.5, 100)} if frequency_bands is None else frequency_bands
        )
        self.max_random_shift = max_random_shift
        self.window_size = window_size
        self.window_stride = window_stride
        self.baseline_window = baseline_window
        self.notch_filter = notch_filter
        self.brain_clipping = brain_clipping
        self.scaling = scaling

        # Audio
        self.audio_model = audio_model
        self.audio_sample_rate = audio_sample_rate
        self.hop_length = hop_length

        # Hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.batch_size = batch_size
        self.use_clip_loss = use_clip_loss
        self.use_mse_loss = use_mse_loss
        self.alpha = alpha
        self.random_test_size = random_test_size
        self.seed = seed

        assert 0 <= self.alpha <= 1, "Alpha must be between 0 and 1"
        assert use_clip_loss or use_mse_loss, "At least one loss function must be used"

    # does not override parent method
    def to_dict_(self):
        """Returns the parent's `to_dict()` result with the brain encoder
        configuration replaced by its own `to_dict()` output."""
        brain_encoder_config = self.brain_encoder_config.to_dict()
        config = self.to_dict()
        config["brain_encoder_config"] = brain_encoder_config
        return config

    def from_dict(self, config: tp.Dict[str, tp.Any]) -> "TrainingConfigV0":
        """Overwrites every attribute of this instance from `config`.

        Arguments:
            config -- Dict holding one entry per attribute;
                    "brain_encoder_config" is passed to
                    `SimpleConvConfig.from_dict`.

        Returns:
            This instance, to allow chaining.

        Raises:
            KeyError -- if any expected key is missing from `config`.
        """
        self.brain_encoder_config = SimpleConvConfig.from_dict(
            config["brain_encoder_config"]
        )
        # Every remaining field is copied verbatim under the same name.
        plain_fields = (
            "data_partition",
            "new_freq",
            "frequency_bands",
            "max_random_shift",
            "window_size",
            "window_stride",
            "baseline_window",
            "notch_filter",
            "brain_clipping",
            "scaling",
            "audio_model",
            "audio_sample_rate",
            "hop_length",
            "learning_rate",
            "weight_decay",
            "epochs",
            "batch_size",
            "use_clip_loss",
            "use_mse_loss",
            "alpha",
            "random_test_size",
            "seed",
        )
        for field in plain_fields:
            setattr(self, field, config[field])
        return self


class TrainingSessionV0(TrainingSession):
    def __init__(
        self,
        config: TrainingConfigV0 = None,
        studies: tp.Dict[str, str] = None,
        data_path: str = "/home/ubuntu/brain-decoding/data",
        save_path: str = "/home/ubuntu/brain-decoding/saves",
        clear_cache: bool = False,
        cache_enabled: bool = True,
        max_cache_size: int = 100,
    ):
        """Initializes a training session with the provided configuration and data.
        This version deals with audio batches.

        Arguments:
            config -- The configuration for the training session.
            studies -- dict of studies, batch type. Partition policy determined in TrainingConfig
                    Batch type determines how to load data from study.

            data_path -- The path to the data directory.
            save_path -- The path to the directory where the model and logs will be saved.
            clear_cache -- Whether to clear the cache for the studies.
            cache_enabled -- Whether to enable caching for the studies.
            max_cache_size -- The maximum number of stimuli in cache.
        """

        super().__init__(
            config=config,
            studies=studies,
            data_path=data_path,
            save_path=save_path,
            clear_cache=clear_cache,
            cache_enabled=cache_enabled,
            max_cache_size=max_cache_size,
        )

        # Fill in the possible values of each condition the encoder is
        # configured with. The dict is updated in place on the encoder config.
        conditions = self.config.brain_encoder_config.conditions
        if conditions:

            if "study" in conditions:
                conditions["study"] = list(studies.keys())

            # NOTE(review): this key is "subjects" while `run_batch` passes a
            # "subject" key to the model — confirm which one the model reads.
            if "subjects" in conditions:
                subjects = {
                    f"{recording.study_name}_{recording.subject_id}"
                    for recording in self.recordings
                }
                # Sorted so the order is identical across processes; plain
                # set iteration order of strings changes between runs.
                conditions["subjects"] = sorted(subjects)

        self.model = SimpleConv(self.config.brain_encoder_config)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )
        # Keep `mse_loss` as a callable: `run_batch` invokes it with
        # input/target/reduction. Calling it here without arguments raises.
        self.clip_loss = CLIPLoss()
        self.mse_loss = torch.nn.functional.mse_loss

    def train(
        self,
        device: str,
        buffer_size: int,
        num_workers: int,
        max_cache_size: int,
        current_epoch: int = 0,
    ):
        """Trains the model from epoch `current_epoch + 1` up to `config.epochs`.

        Each epoch shuffles a copy of the training recordings (seeded with
        `config.seed + epoch`), streams them through the dataloader, runs
        `run_batch` on every recording, then runs `test` and saves a
        checkpoint named `epoch_{epoch}`.

        Arguments:
            device -- Device the model and CLIP loss are moved to.
            buffer_size -- Forwarded to `get_dataloader`.
            num_workers -- Forwarded to `get_dataloader`.
            max_cache_size -- Forwarded to `get_dataloader`.
            current_epoch -- Last completed epoch; training resumes after it.

        Raises:
            Exception -- re-raised after a checkpoint `error_epoch_{epoch}` is
                    saved when epoch initialisation or testing fails.
                    Failures of single recordings are logged and skipped.
        """

        # Set all training parameters
        self.device = device
        gpu_ok = False
        torch.set_float32_matmul_precision("high")
        training_size = len(self.dataset["train"])
        # Keep an existing scaler (e.g. one restored from a checkpoint)
        # instead of discarding its state.
        if getattr(self, "scaler", None) is None:
            self.scaler = GradScaler()
        self.model.to(device)
        self.clip_loss.to(device)

        # Check if GPU is NVIDIA V100, A100, or H100
        if torch.cuda.is_available():
            device_cap = torch.cuda.get_device_capability()
            if device_cap in ((7, 0), (8, 0), (9, 0)):
                gpu_ok = True
        if not gpu_ok:
            self.log_print(
                "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected."
            )

        # Fetch recordings
        if self.dataloader is None:
            self.dataloader = self.get_dataloader(
                buffer_size=buffer_size,
                num_workers=num_workers,
                max_cache_size=max_cache_size,
            )

        for epoch in range(current_epoch + 1, self.config.epochs + 1):
            try:
                self.model.to(device).train()
                epoch_start_time = time.time()

                # Shuffle for each epoch, and start fetching
                epoch_training_dataset, remaining = (
                    self.dataset["train"].copy(),
                    training_size,
                )
                # For reproducibility
                self.set_seed(int(self.config.seed + epoch))
                random.shuffle(epoch_training_dataset)
                self.dataloader.start_fetching(epoch_training_dataset, cache=True)

            except Exception as e:
                self.log_print(f"Error in epoch {epoch} during initialization, {e}")
                self.save(f"error_epoch_{epoch}")
                # The epoch cannot run without a started dataloader; continuing
                # would hit undefined state below, so stop like the testing
                # handler does.
                raise

            # Run each batch
            while True:

                batch = self.dataloader.get_recording()
                if batch is None:
                    break

                try:
                    start_time = time.time()
                    results = self.run_batch(batch, train=True)
                    self.metrics["train"].append(results)

                    # Don't print, just log
                    self.logger.info(
                        f"Epoch {epoch}, Remaining {remaining}/{training_size}. Runtime {time.time() - start_time:.2f}s."
                    )
                    self.logger.info(
                        f'Loss: {results["loss"]:.4f}, Clip Loss: {results["clip_loss"]:.4f}, MSE Loss: {results["mse_loss"]:.4f}, Commitment Loss: {results["commitment_loss"]:.4f}'
                    )
                    self.logger.info(
                        f'Accuracy: {results["accuracy"]:.4f}, Top 1: {results["top_1_accuracy"]:.4f}, Top 5: {results["top_5_accuracy"]:.4f}, Top 10: {results["top_10_accuracy"]:.4f}, Perplexity: {results["perplexity"]:.4f}'
                    )
                    remaining -= 1
                except Exception as e:
                    # Do log errors, including the cause
                    self.log_print(
                        f"Error in epoch {epoch}, {batch.recording.study_name} {batch.recording.subject_id} {batch.recording.session_id} {batch.recording.task_id}. Skipping. Error: {e}"
                    )
                    # A skipped recording is no longer pending either
                    remaining -= 1
                    continue

            elapsed_minutes = (time.time() - epoch_start_time) / 60
            # Guard the per-recording average against an empty training set
            minutes_per_recording = elapsed_minutes / max(training_size, 1)
            self.log_print(
                f"Epoch {epoch} completed in {elapsed_minutes:.2f}m. {minutes_per_recording:.2f}m per recording."
            )

            # Testing
            try:
                self.log_print(f"Testing at epoch {epoch}")
                with torch.no_grad():
                    self.test(
                        buffer_size=buffer_size,
                        num_workers=num_workers,
                        max_cache_size=max_cache_size,
                    )
            except Exception as e:
                self.log_print(f"Error in epoch {epoch} during testing, {e}")
                self.save(f"error_epoch_{epoch}")
                raise

            # Save model
            self.save(f"epoch_{epoch}")

        self.log_print("Training completed.")

    def run_batch(self, batch: AudioBatch, train: bool) -> tp.Dict[str, float]:
        """
        Per recording processing for training and testing. Returns average metrics
        and losses for the recording. Returns metrics on CPU.

        Arguments:
            batch -- Unpacks into (brain_segments, audio_segments, recording);
                    only the "all" entry of brain_segments is used.
            train -- When True the optimizer is stepped for every mini-batch;
                    when False only the forward pass and metrics are computed.

        Returns:
            Dict with "loss", "clip_loss", "mse_loss", "commitment_loss" and
            "perplexity" averaged over the processed mini-batches, and
            "accuracy", "top_1_accuracy", "top_5_accuracy", "top_10_accuracy"
            averaged over the processed segments. Every value is 0 when
            nothing could be processed.
        """

        # Some processing to ensure dims match
        brain_segments, audio_segments, recording = batch
        brain_segments, audio_segments = self.discard_nan(
            brain_segments["all"], audio_segments
        )

        # Models config decides if it is used
        conditions = {
            "study": str(recording.study_name),
            "subject": str(recording.study_name) + "_" + str(recording.subject_id),
        }

        # Number of segments; must be known before the mini-batch bounds are built
        total = brain_segments.shape[0]

        # Shuffle segments
        shuffle_indices = torch.randperm(total)
        brain_segments, audio_segments = (
            brain_segments[shuffle_indices].to(self.device),
            audio_segments[shuffle_indices].to(self.device),
        )  # [B, C, T], [B, mel_bins, T]

        # Process by specified batch size
        batch_indices = [
            (i, min(i + self.config.batch_size, total))
            for i in range(0, total, self.config.batch_size)
        ]

        # Initialize recording metrics
        recording_loss = 0
        recording_clip_loss = 0
        recording_mse_loss = 0
        recording_commitment_loss = 0
        recording_correct = 0
        recording_top_1 = 0
        recording_top_5 = 0
        recording_top_10 = 0
        recording_perplexity = 0
        missed_recordings, missed_batches = 0, 0

        # Gradients are only tracked when training
        with torch.set_grad_enabled(train), autocast(dtype=torch.bfloat16):

            for start, end in batch_indices:

                try:

                    if train:
                        self.optimizer.zero_grad()

                    # Slice by batch
                    brain_batch, audio_batch = (
                        brain_segments[start:end],
                        audio_segments[start:end],
                    )

                    # Forward pass; the model is told whether this is a
                    # training or an evaluation pass
                    (output, quantizer_metrics) = self.model(
                        x=brain_batch,
                        recording=recording,
                        conditions=conditions,
                        mel=audio_batch,
                        train=train,
                    )  # [B, C, T]

                    # Compute loss
                    mse_loss = self.mse_loss(
                        input=output, target=audio_batch, reduction="mean"
                    )
                    clip_results = self.clip_loss(x_1=output, x_2=audio_batch)
                    clip_loss, clip_metrics = clip_results["loss"], clip_results["metrics"]

                    # Sum loss based on config
                    if self.config.use_clip_loss and self.config.use_mse_loss:
                        loss = ((1 - self.config.alpha) * mse_loss) + (
                            self.config.alpha * clip_loss
                        )
                    elif not self.config.use_clip_loss and self.config.use_mse_loss:
                        loss = mse_loss
                    elif self.config.use_clip_loss and not self.config.use_mse_loss:
                        loss = clip_loss

                    if quantizer_metrics is not None:
                        if "commitment_loss" in quantizer_metrics:
                            loss += quantizer_metrics["commitment_loss"]

                    # Backward pass
                    if not torch.isnan(loss).any():

                        if train:
                            self.scaler.scale(loss).backward()
                            self.scaler.step(self.optimizer)
                            self.scaler.update()

                        # Store losses, move to CPU
                        recording_loss += loss.detach().to("cpu").item()
                        recording_clip_loss += clip_loss.detach().to("cpu").item()
                        recording_mse_loss += mse_loss.detach().to("cpu").item()

                        # Store metrics, already on CPU
                        recording_correct += clip_metrics["correct"]
                        recording_top_1 += clip_metrics["top_1_correct"]
                        recording_top_5 += clip_metrics["top_5_correct"]
                        recording_top_10 += clip_metrics["top_10_correct"]

                        # Quantizer metrics
                        if quantizer_metrics is not None:
                            if "perplexity" in quantizer_metrics:
                                perplexity = (
                                    quantizer_metrics["perplexity"]
                                    .detach()
                                    .to("cpu")
                                    .mean(dim=0)
                                )
                                recording_perplexity += perplexity.item()
                            if "commitment_loss" in quantizer_metrics:
                                recording_commitment_loss += (
                                    quantizer_metrics["commitment_loss"]
                                    .detach()
                                    .to("cpu")
                                    .item()
                                )
                    else:
                        self.logger.info(
                            f"Loss is NaN for {recording.study_name} {recording.subject_id} {recording.session_id} {recording.task_id}."
                        )
                        missed_recordings += end - start
                        missed_batches += 1

                except Exception as e:
                    self.logger.info(
                        f"Error in processing {recording.study_name} {recording.subject_id} {recording.session_id} {recording.task_id}. Error: {e}"
                    )
                    missed_recordings += end - start
                    missed_batches += 1
                    continue

        gc.collect()
        torch.cuda.empty_cache()

        # Correct for missed recordings and batches
        total -= missed_recordings
        batches = len(batch_indices) - missed_batches

        def average(value, count):
            # 0 instead of a ZeroDivisionError when nothing was processed
            return value / count if count > 0 else 0

        # Loss divided by batches, metrics by total
        return {
            "loss": average(recording_loss, batches),
            "clip_loss": average(recording_clip_loss, batches),
            "mse_loss": average(recording_mse_loss, batches),
            "commitment_loss": average(recording_commitment_loss, batches),
            "perplexity": average(recording_perplexity, batches),
            "accuracy": average(recording_correct, total),
            "top_1_accuracy": average(recording_top_1, total),
            "top_5_accuracy": average(recording_top_5, total),
            "top_10_accuracy": average(recording_top_10, total),
        }

    def test(self, buffer_size: int, num_workers: int, max_cache_size: int):

        self.model.eval().to(self.device)
        self.set_seed(self.config.seed)
        test_start_time = time.time()

        test_datasets = {}

        # Create dataset and loader
        for test in self.dataset["test"].keys():
            # Randomly subsample recordings for each type of test
            if len(self.dataset["test"][test]) < self.config.random_test_size:
                test_datasets[test] = self.dataset["test"][test]
            else:
                test_datasets[test] = random.sample(
                    self.dataset["test"][test], self.config.random_test_size
                )

            if self.test_dataloader.get(test) is None:
                self.test_dataloader[test] = self.get_dataloader(
                    buffer_size=buffer_size,
                    num_workers=num_workers,
                    max_cache_size=max_cache_size,
                )
            self.test_dataloader[test].start_fetching(test_datasets[test], cache=True)

        test_sizes = {test: len(test_datasets[test]) for test in test_datasets.keys()}

        # Run tests
        for test in test_datasets.keys():

            while True:

                batch = self.test_dataloader[test].get_recording()
                if batch is None:
                    break

                try:

                    start_time = time.time()

                    results = self.run_batch(batch, train=False)
                    self.metrics["test"][test].append(results)

                    # Log results
                    self.logger.info(
                        f"Testing {test} {test_sizes[test]}/{len(test_datasets[test])}. Runtime {time.time() - start_time:.2f}s."
                    )
                    self.logger.info(
                        f'Loss: {results["loss"]:.4f}, Clip Loss: {results["clip_loss"]:.4f}, MSE Loss: {results["mse_loss"]:.4f}, Commitment Loss: {results["commitment_loss"]:.4f}'
                    )
                    self.logger.info(
                        f'Accuracy: {results["accuracy"]:.4f}, Top 1: {results["top_1_accuracy"]:.4f}, Top 5: {results["top_5_accuracy"]:.4f}, Top 10: {results["top_10_accuracy"]:.4f}, Perplexity: {results["perplexity"]:.4f}'
                    )

                except Exception as e:
                    self.log_print(
                        f"Error in testing {test}, {batch.recording.study_name} {batch.recording.subject_id} {batch.recording.session_id} {batch.recording.task_id}. Skipping."
                    )
                    test_sizes[test] -= 1
                    continue

        # Log info
        elapsed_minutes = (time.time() - test_start_time) / 60
        self.logger.info(f"Testing completed in {elapsed_minutes:.2f}m.")
        return

    def get_dataloader(self, buffer_size, num_workers, max_cache_size):
        """Creates a DataLoader for audio batches from the session config.

        Arguments:
            buffer_size -- Passed through as the loader's `buffer_size`.
            num_workers -- Worker count assigned to the "audio" batch type.
            max_cache_size -- Passed through as `max_cache_size_gb`.

        Returns:
            The newly constructed DataLoader.
        """
        cfg = self.config

        # Settings specific to the "audio" batch type
        audio_kwargs = {
            "max_random_shift": cfg.max_random_shift,
            "window_size": cfg.window_size,
            "window_stride": cfg.window_stride,
            "audio_sample_rate": cfg.audio_sample_rate,
            "hop_length": cfg.hop_length,
            "audio_processor": cfg.audio_model,
        }

        return DataLoader(
            buffer_size=buffer_size,
            max_cache_size_gb=max_cache_size,
            cache_dir="cache",
            notch_filter=cfg.notch_filter,
            frequency_bands=cfg.frequency_bands,
            scaling=cfg.scaling,
            brain_clipping=cfg.brain_clipping,
            baseline_window=cfg.baseline_window,
            new_freq=cfg.new_freq,
            batch_types={"audio": num_workers},
            batch_kwargs={"audio": audio_kwargs},
        )

    def discard_nan(
        self,
        brain: torch.Tensor,
        audio: torch.Tensor,
    ):
        """
        If any nan in brain or audio data, discard the batch.

        Arguments:
            brain -- The brain data tensor, [B, C, T]
            audio -- The audio data, [B, mel_bins, T]
        """

        valid_mask = ~(
            torch.isnan(brain).any(dim=(1, 2)) | torch.isnan(audio).any(dim=(1, 2))
        )

        if valid_mask.all():
            return brain, audio

        # Apply the same mask to both tensors
        filtered_brain = brain[valid_mask]
        filtered_audio = audio[valid_mask]

        if filtered_brain.shape[0] != filtered_audio.shape[0]:
            raise ValueError(
                "Filtered brain and audio data must have the same number of samples"
            )

        return filtered_brain, filtered_audio

    def pre_process_all_recordings(
        self, buffer_size: int, num_workers: int, max_cache_size: int
    ):
        """Pre-processes all data and saves as .pt in cache at once.

        Partitions the datasets and creates the dataloader if that has not
        happened yet, then pulls every recording through the dataloader while
        showing a progress bar.

        Arguments:
            buffer_size -- Forwarded to `get_dataloader`.
            num_workers -- Forwarded to `get_dataloader`.
            max_cache_size -- Forwarded to `get_dataloader`.
        """

        if self.recordings is None:
            self.partition_datasets()

        if self.dataloader is None:
            self.dataloader = self.get_dataloader(
                buffer_size, num_workers, max_cache_size
            )

        # NOTE(review): the other call sites pass `cache=True` to
        # `start_fetching`; confirm the default used here also writes the cache.
        self.dataloader.start_fetching(self.recordings)

        # The context manager closes the progress bar even if fetching fails
        with tqdm(total=len(self.recordings), desc="Loading recordings") as pbar:
            while True:
                recording = self.dataloader.get_recording()
                if recording is None:
                    break
                pbar.update(1)

    def save(self, name: str):
        """Saves the model and logs to the save path.

        Writes `training_config.json` into `save_path` when it is missing,
        then writes `model.pt` (config, model, conditions, optimizer, scaler,
        error) and `metrics.pt` (metrics, error) into `save_path/name`.

        Arguments:
            name -- Name of the checkpoint sub-directory.
        """
        # `to_dict_` also converts the nested brain encoder config; fall back
        # to the plain `to_dict` for configs that do not provide it.
        if hasattr(self.config, "to_dict_"):
            config_dict = self.config.to_dict_()
        else:
            config_dict = self.config.to_dict()

        # Training session config, written once per save path
        os.makedirs(self.save_path, exist_ok=True)
        config_file = os.path.join(self.save_path, "training_config.json")
        if not os.path.exists(config_file):
            with open(config_file, "w") as json_file:
                json.dump(config_dict, json_file, indent=4)

        # The checkpoint directory has to exist before torch.save writes into it
        checkpoint_path = os.path.join(self.save_path, name)
        os.makedirs(checkpoint_path, exist_ok=True)

        error = getattr(self, "error", None)
        error_message = str(error) if error else "No errors."
        # The scaler only exists once `train` has been called
        scaler = getattr(self, "scaler", None)

        try:
            # Save model
            torch.save(
                {
                    "config": config_dict,
                    "model": self.model.cpu().state_dict(),
                    "conditions": self.model.condition_to_idx,
                    "optimizer": self.optimizer.state_dict(),
                    "scaler": scaler.state_dict() if scaler is not None else None,
                    "error": error_message,
                },
                os.path.join(checkpoint_path, "model.pt"),
            )

            # Save metrics
            torch.save(
                {
                    "metrics": self.metrics,
                    "error": error_message,
                },
                os.path.join(checkpoint_path, "metrics.pt"),
            )
        finally:
            # The model was moved to CPU for saving; move it back even when
            # saving failed so training can continue on the same device.
            device = getattr(self, "device", None)
            if device is not None:
                self.model.to(device)
            gc.collect()
            torch.cuda.empty_cache()

        return


def load_training_session(
    save_path: str,
    studies: tp.Dict[str, str] = None,
    data_path: str = "/home/ubuntu/brain-decoding/data",
    clear_cache: bool = False,
    cache_enabled: bool = True,
    max_cache_size: int = 100,
):
    """Loads a training session from the save path.

    Arguments:
        save_path -- Directory that contains `model.pt` and `metrics.pt`.
        studies -- Forwarded to `TrainingSessionV0`.
        data_path -- Forwarded to `TrainingSessionV0`.
        clear_cache -- Forwarded to `TrainingSessionV0`.
        cache_enabled -- Forwarded to `TrainingSessionV0`.
        max_cache_size -- Forwarded to `TrainingSessionV0`.

    Returns:
        A `TrainingSessionV0` with model, optimizer, scaler, error and
        metrics restored from the checkpoint.

    Raises:
        FileNotFoundError -- if `save_path` does not exist.
        ValueError -- if anything fails while loading, including a mismatch
                between the saved and the rebuilt model conditions.
    """

    if not os.path.exists(save_path):
        raise FileNotFoundError(f"Save path {save_path} does not exist.")

    def condition_keys(conditions) -> set:
        # Accept a mapping or any collection of names; None counts as empty
        if conditions is None:
            return set()
        if isinstance(conditions, dict):
            return set(conditions.keys())
        return set(conditions)

    try:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # Loading onto the CPU keeps this working on machines without the
        # device the checkpoint was written from.
        load = torch.load(f"{save_path}/model.pt", map_location="cpu")

        # The constructor requires both arguments; they are placeholders
        # here because `from_dict` overwrites every attribute.
        config = TrainingConfigV0(
            brain_encoder_config=None, data_partition=None
        ).from_dict(load["config"])

        training_session = TrainingSessionV0(
            config=config,
            studies=studies,
            data_path=data_path,
            save_path=save_path,
            clear_cache=clear_cache,
            cache_enabled=cache_enabled,
            max_cache_size=max_cache_size,
        )

        # Load model
        training_session.model.load_state_dict(load["model"])
        training_session.optimizer.load_state_dict(load["optimizer"])

        # The session constructor does not create a scaler, and an empty
        # saved state cannot be loaded, so both cases are guarded.
        scaler_state = load.get("scaler")
        if scaler_state:
            if getattr(training_session, "scaler", None) is None:
                training_session.scaler = GradScaler()
            training_session.scaler.load_state_dict(scaler_state)
        training_session.error = load["error"]

        # Load metrics
        metrics = torch.load(f"{save_path}/metrics.pt", map_location="cpu")
        training_session.metrics = metrics["metrics"]

        # Compare plain sets of names: a dict_keys view never equals a
        # stored dict or list, which would make this check always fail.
        model_conditions = condition_keys(training_session.model.condition_to_idx)
        if model_conditions != condition_keys(load["conditions"]):
            raise ValueError("Condition to idx mismatch.")

        return training_session

    except Exception as e:
        raise ValueError(f"Error loading training session config, {e}") from e