In [1]:
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# Configuration for the SimpleConv brain encoder used in this notebook.
# In this setup the merger, conditional layers, quantizer and both
# transformer stacks are switched off (False / None / 0 below).
model_config = SimpleConvConfig(
    mel_normalization=False,
    # Maps each condition name (str) to its list of possible values;
    # both lists are left empty here.
    conditions={
        "study": [],
        "subject": [],
    },
    # Channels
    in_channels=208,
    out_channels=128,
    hidden_dim=384,
    dropout=0.2,
    initial_batch_norm=True,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=True,
    layout_scaling="minmax",
    # Merger with spatial attention (disabled: merger=False)
    merger=False,
    merger_emb_type=None,
    merger_emb_dim=0,
    merger_channels=0,
    merger_dropout=False,
    merger_conditional=None,
    # Initial layers
    initial_linear=384,
    initial_depth=1,
    # Conditional layers (disabled)
    conditional_layers=False,
    conditional_layers_dim=None,  # input or hidden_dim
    # Overall structure of the convolutional stack
    depth=6,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.2,
    batch_norm=True,
    half=True,
    cnn_pos_encoding=False,
    # Quantizer (disabled: quantizer=False, all related values zeroed)
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformer encoder (0 layers / 0 heads here)
    transformer_input=None,
    transformer_encoder_emb=None,
    transformer_encoder_layers=0,
    transformer_encoder_heads=0,
    # Transformer decoder (0 layers / 0 heads here)
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

In [None]:
import typing as tp
import gc
import random
import time
from tqdm import tqdm
import typing as tp
import json
from torch.optim import AdamW
import os
import torch

from dataloader import DataLoader
from dataloader.audio_batch import AudioBatch
from config import TrainingConfigV1
from losses.mse import mse_loss_per_batch
from losses.cos_sim import cosine_similarity_loss
from train.training_session import TrainingSession
from models.whisper_alignment import WhisperAlignment

# Module-level device, read by TrainingSessionV1.__init__ when it builds the
# GradScaler. Falls back to CPU so the notebook still runs on machines
# without CUDA; on CUDA machines the value is "cuda" exactly as before.
device = "cuda" if torch.cuda.is_available() else "cpu"


class TrainingSessionV1(TrainingSession):
    """
    Training session that works on audio batches to align a brain encoder
    with Whisper latents.
    """

    def __init__(
        self,
        config: TrainingConfigV1 = None,
        studies: tp.Dict[str, str] = None,
        data_path: str = "/home/ubuntu/brain-decoding/data",
        save_path: str = "/home/ubuntu/brain-decoding/saves",
        clear_cache: bool = False,
        max_cache_size: int = 100,
        cache_name: str = "cache",
    ):
        """
        Initializes a training session with the provided configuration and data.
        This version deals with audio batches for Whisper latent alignment,
        architecture exploration, and dataset integration.

        Arguments:
            config -- The configuration for the training session. Required:
                    the model and optimizer are built from it.
            studies -- dict of studies, batch type. Partition policy determined in TrainingConfig
                    Batch type determines how to load data from study.

            data_path -- The path to the data directory.
            save_path -- The path to the directory where the model and logs will be saved.
            clear_cache -- Whether to clear the cache for the studies.
            max_cache_size -- The maximum number of stimuli in cache. Forwarded
                    to the base class; note that train() takes a separate
                    max_cache_size expressed in GB.
            cache_name -- Forwarded unchanged to the base class.

        Raises:
            ValueError -- if config is None.
        """
        # Fail fast with a clear message: everything below dereferences
        # `config`, so a missing config can never produce a usable session.
        if config is None:
            raise ValueError("TrainingSessionV1 requires a config, got None")

        super().__init__(
            config=config,
            studies=studies,
            data_path=data_path,
            save_path=save_path,
            clear_cache=clear_cache,
            cache_enabled=True,
            max_cache_size=max_cache_size,
            cache_name=cache_name,
        )

        # MODEL
        self.model = WhisperAlignment(
            brain_module_config=config.brain_encoder_config,
            adalora_config=config.adalora_config,
            layers_to_align=config.latent_alignment_layers,
            use_compile=False,
        )

        # Only parameters that still require gradients are optimized.
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        # NOTE(review): this reads the module-level `device`, not the device
        # later passed to train(); confirm the two are meant to agree.
        self.scaler = torch.amp.GradScaler(device=device)

        # Loss functions
        self.clip_loss = self.model.brain_module.clip_loss
        self.mse_loss = mse_loss_per_batch
        self.cosine_similarity_loss = cosine_similarity_loss

    def train(
        self,
        device: str,
        buffer_size: int,
        num_workers: int,
        max_cache_size: int,
        current_epoch: int = 0,
    ):
        """
        Runs the epochs from current_epoch + 1 up to and including config.epochs.

        Arguments:
            device -- Device the model and the CLIP loss are moved to.
            buffer_size -- Forwarded to get_dataloader().
            num_workers -- Forwarded to get_dataloader().
            max_cache_size -- Max cache size for the cache dir in GB.
            current_epoch -- Epoch to resume after; training starts at the next one.

        Raises:
            Exception -- any error raised while setting up an epoch is logged,
                    save() is called with an "error_epoch_<n>" name, and the
                    error is re-raised.
        """

        # Set all training parameters
        self.device = device
        self.model.to(device)
        self.clip_loss.to(device)
        training_size = len(self.dataset["train"])

        for epoch in range(current_epoch + 1, self.config.epochs + 1):
            try:
                self.model.to(device).train()
                epoch_start_time = time.time()

                # Work on a copy so the shuffle below never reorders self.dataset
                epoch_training_dataset = self.dataset["train"].copy()

                # Fetch recordings
                dataloader = self.get_dataloader(
                    buffer_size=buffer_size,
                    num_workers=num_workers,
                    max_cache_size=max_cache_size,
                )

                # For reproducibility: the seed depends on the epoch, so every
                # epoch gets a different but repeatable order
                self.set_seed(int(self.config.seed + epoch))
                random.shuffle(epoch_training_dataset)
                dataloader.start_fetching(epoch_training_dataset, cache=True)

            except Exception as e:
                self.log_print(f"Error in epoch {epoch} during initialization, {e}")
                self.save(f"error_epoch_{epoch}")
                # Re-raise: continuing would reach code that needs
                # `epoch_training_dataset` / `dataloader`, which may not exist,
                # and would hide the real error behind a NameError.
                raise

            pbar = tqdm(
                total=len(epoch_training_dataset), desc="Training Epoch " + str(epoch)
            )

            all_metrics, total_batches = [], 0

            # NOTE(review): the batch loop below is an unfinished stub. It has
            # no exit condition and never consumes `dataloader`, so as written
            # it does not terminate. Complete it before calling train().
            while True:

                # Combines multiple recordings so that there is some mixing of
                # studies and subjects in each batch
                multi_batch = []

        # TODO (kept from the original, still disabled):
        #     self.model.encoder.update_and_allocate(step)

    def get_dataloader(self, buffer_size, num_workers, max_cache_size):
        """
        Builds a DataLoader for audio batches from the session config.

        Arguments:
            buffer_size -- Passed to the DataLoader as buffer_size.
            num_workers -- Number assigned to the "audio" batch type.
            max_cache_size -- Passed to the DataLoader as max_cache_size_gb.

        Returns:
            The newly constructed DataLoader.
        """
        # NOTE(review): cache_dir is hard-coded to "cache" and does not follow
        # the cache_name given to __init__; confirm whether they should match.
        dataloader = DataLoader(
            buffer_size=buffer_size,
            max_cache_size_gb=max_cache_size,
            cache_dir="cache",
            notch_filter=self.config.notch_filter,
            frequency_bands=self.config.frequency_bands,
            scaling=self.config.scaling,
            brain_clipping=self.config.brain_clipping,
            baseline_window=self.config.baseline_window,
            new_freq=self.config.new_freq,
            delay=self.config.delay,
            batch_types={"audio": num_workers},
            batch_kwargs={
                "audio": {
                    "max_random_shift": self.config.max_random_shift,
                    "window_size": self.config.window_size,
                    "window_stride": self.config.window_stride,
                    "audio_sample_rate": self.config.audio_sample_rate,
                    "hop_length": self.config.hop_length,
                    "audio_processor": self.config.audio_model,
                }
            },
        )
        return dataloader

    def discard_nan(
        self,
        brain: torch.Tensor,
        audio: torch.Tensor,
    ):
        """
        Drops every sample (row along dim 0) that contains a NaN in either
        the brain or the audio tensor. Samples without NaN are kept, in order.

        Arguments:
            brain -- The brain data tensor, [B, C, T]
            audio -- The audio data, [B, mel_bins, T]

        Returns:
            (brain, audio) restricted to the NaN-free samples. When nothing
            has to be dropped the input tensors are returned unchanged.

        Raises:
            ValueError -- if brain and audio do not have the same number of samples.
        """
        # Validate before masking: once the same mask is applied to both
        # tensors their sizes can no longer differ, and a size-1 batch would
        # otherwise be silently broadcast in the mask below.
        if brain.shape[0] != audio.shape[0]:
            raise ValueError(
                "Brain and audio data must have the same number of samples"
            )

        # Collapse all non-batch dims so a single any() decides per sample
        brain_has_nan = torch.isnan(brain).flatten(start_dim=1).any(dim=1)
        audio_has_nan = torch.isnan(audio).flatten(start_dim=1).any(dim=1)
        valid_mask = ~(brain_has_nan | audio_has_nan)

        if valid_mask.all():
            return brain, audio

        # Apply the same mask to both tensors so the pairs stay aligned
        return brain[valid_mask], audio[valid_mask]

    def pre_process_all_recordings(
        self, buffer_size: int, num_workers: int, max_cache_size: int
    ):
        """
        Pre-processes all data and saves as .pt in cache at once.

        Arguments:
            buffer_size -- Forwarded to get_dataloader().
            num_workers -- Forwarded to get_dataloader().
            max_cache_size -- Forwarded to get_dataloader() (cache size in GB).
        """

        if self.recordings is None:
            self.partition_datasets()

        dataloader = self.get_dataloader(buffer_size, num_workers, max_cache_size)

        # The context manager closes the progress bar even if fetching fails
        with tqdm(total=len(self.recordings), desc="Loading recordings") as pbar:
            dataloader.start_fetching(self.recordings)

            # Drain the loader; the first None ends the loop
            while dataloader.get_recording() is not None:
                pbar.update(1)

    def save(self, name: str):
        """
        Placeholder: currently does nothing.

        NOTE(review): train() calls this from its error handler, so no
        checkpoint is written there until this is implemented.

        Arguments:
            name -- Name for the save (unused for now).
        """
        pass


def load_training_session(
    save_path: str,
    studies: tp.Dict[str, str] = None,
    data_path: str = "/home/ubuntu/brain-decoding/data",
    clear_cache: bool = False,
    cache_enabled: bool = True,
    max_cache_size: int = 100,
):
    """
    Placeholder for restoring a saved training session. Not implemented yet:
    every argument is accepted and ignored.

    Arguments:
        save_path -- Path of the saved session.
        studies -- dict of studies, batch type.
        data_path -- The path to the data directory.
        clear_cache -- Whether to clear the cache for the studies.
        cache_enabled -- Whether caching is enabled.
        max_cache_size -- Maximum cache size.

    Returns:
        None
    """
    return None

In [2]:
from peft import AdaLoraConfig

# AdaLoRA adapter settings. Presumably passed to the training config as
# `adalora_config` — confirm against the cell that builds TrainingConfigV1.
adalora_config = AdaLoraConfig(
    peft_type="ADALORA",
    # NOTE(review): an earlier comment suggested "SPEECH_RECOGNITION" as an
    # alternative; that does not look like a PEFT TaskType — confirm against
    # the PEFT docs before changing this value.
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],  # which modules to adapt
    init_r=12,  # initial rank
    target_r=8,  # final average rank
    tinit=100,  # begin rank updates after 100 steps
    # NOTE(review): originally described as "finish rank updates by step 1000";
    # PEFT may define tfinal as a count of final fine-tuning steps rather than
    # an absolute step — verify against the AdaLoraConfig documentation.
    tfinal=1000,
    deltaT=100,  # re-allocate every 100 steps
    lora_alpha=32,
    lora_dropout=0.1,
    total_step=2000,  # your total training steps
)