In [None]:
# from utils.compression import compress_directories, decompress_directories

# base_path = "downloaded_data/gwilliams"
# destination_path = "data/gwilliams"

# decompress_directories(
#     base_path,
#     destination_path,
#     checksum_file_name="checksums.txt",
#     delete_compressed_files=True,
#     num_workers=None
# )

In [None]:
from config.simpleconv_config import SimpleConvConfig
from models.simpleconv import SimpleConv
from studies.study_factory import StudyFactory
from utils.pre_processor import PreProcessor
from utils.fetch import fetch_audio_and_brain_pairs
import typing as tp
import json
from itertools import product
from torch.optim import AdamW, Adam
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import os
import logging

import torch
device = 'cuda'

from config import SimpleConvConfig, Config

class TrainingConfigV0(Config):
    def __init__(
        self,
        brain_encoder_config: SimpleConvConfig,
        data_partition: tp.Dict[str, tp.Dict[str, tp.List[str]]],
        # Pre-processing parameters
        # Brain
        brain_sample_rate: int = 100,
        band_pass_filter: tp.Tuple[str, tp.Tuple[int, int]] = {"all": (0.5, 100)},
        max_random_shift: float = 2.0,
        window_size: int = 4,
        window_stride: int = 1,
        brain_clipping: float = 20,
        baseline_window: int = 0.5,
        notch_filter: bool = True,
        # Audio
        audio_model: str = "openai/whisper-large-v3",
        audio_sample_rate: int = 16000,
        hop_length: int = 160,
        
        # Hyperparameters
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        self.brain_encoder_config = brain_encoder_config
        # key: study_name, value: dict with keys: "testing_subjects", "testing_tasks",
        # where each value is a list of int. Ones not specified in either lists are 
        # used for training.
        self.data_partition = data_partition
        
        # Pre-processing parameters
        # Brain
        self.brain_sample_rate = brain_sample_rate
        self.band_pass_filter = band_pass_filter
        self.max_random_shift = max_random_shift
        self.window_size = window_size
        self.window_stride = window_stride
        self.baseline_window = baseline_window
        self.notch_filter = notch_filter
        # Audio
        self.audio_model = audio_model
        self.audio_sample_rate = audio_sample_rate
        self.hop_length = hop_length
        self.brain_clipping = brain_clipping
        
        # Hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        
    # does not overide parent method
    def to_dict_(self):
        brain_encoder_config = self.brain_encoder_config.to_dict()
        config = self.to_dict()
        config["brain_encoder_config"] = brain_encoder_config
        return config

        
class TrainingSessionV0:
    def __init__(self,
        config: TrainingConfigV0,
        studies: tp.List[str],
        data_path: str = '/home/ubuntu/brain-decoding/data',
        save_path: str = '/home/ubuntu/brain-decoding/saves',
    ):
        """Initializes a training session with the provided configuration and data.

        Arguments:
            config -- The configuration for the training session.
            studies -- list of studies to train on. Partition policy determined in TrainingConfig
            data_path -- The path to the data directory.
            save_path -- The path to the directory where the model and logs will be saved.
        """
        assert len(studies) > 0, "At least one study root path must be provided"
        assert all(
            os.path.exists(data_path + "/" + study) for study in studies
        ), "All study root paths must exist"
        if not os.path.exists(save_path):
            os.makedirs(save_path)
            
        logging.basicConfig(
            filename=os.path.join(save_path, "training_log.log"),
            level=logging.INFO,
            format="%(asctime)s %(message)s",
            filemode="w",
        )
        self.logger = logging.getLogger()

        self.config = config
        self.data_path = data_path
        self.save_path = save_path
        
        # Create studies accessor
        self.studies = {}
        for study in studies:
            path = os.path.join(data_path, study)
            try:
                self.studies[study] = StudyFactory.create_study(study, path)
            except ValueError as e:
                self.logger.error(f"Error loading study {study}: {e}")
                
        # Create preprocessor
        self.pre_processor = PreProcessor(
            brain_sample_rate=config.brain_sample_rate, 
            audio_model=config.audio_model
        )

        self.dataset = {
            "train": [],
            "test": {
                "unseen_subject": [],
                "unseen_task": [],
                "unseen_both": [],
            },
        }
        
        self.partition_data()
        
        self.metrics = {
            "train": [],
            "test": {
                "unseen_subject": [],
                "unseen_task": [],
                "unseen_both": [],
            },
        }

        self.model = SimpleConv(self.config.brain_encoder_config)
        self.error = None
        self.optimizer = Adam(
            self.model.parameters(), 
            lr=self.config.learning_rate, 
            weight_decay=self.config.weight_decay
        )

    def partition_data(self):
        """
        Partitions the data into training and various testing sets, based on 
        the named holdout sessions and tasks specified in TrainingConfig
        """
        
        for study_name, study in self.studies.items():
            
            if study_name not in self.config.data_partition:
                raise ValueError(f"Study {study_name} not found in data partition")

            data_partition = self.config.data_partition[study_name]
            
            for subject, task, session in product(
                [i for i in range(len(study.subjects_list))],
                [i for i in range(len(study.tasks))],
                [i for i in range(len(study.sessions))],
            ):
                # If recording exists
                try:
                    recording = study.recordings[subject][task][session]
                except IndexError:
                    self.logger.error(f"Recording not found for {study_name} {subject} {task} {session}")
                    continue
                
                # Unseen both and task
                if subject in data_partition["testing_subjects"]:
                    if task in data_partition["testing_tasks"]:
                        self.dataset["test"]["unseen_both"].append((study_name, subject, task, session))
                    else:
                        self.dataset["test"]["unseen_task"].append((study_name, subject, task, session))
                # Unseen subject and train
                else:
                    if task in data_partition["testing_tasks"]:
                        self.dataset["test"]["unseen_subject"].append((study_name, subject, task, session))
                    else:
                        self.dataset["train"].append((study_name, subject, task, session))
                          
        self.log_print(
            f'Data partitioned on studies {list(self.studies.keys())}. Recordings:'
        )
        self.log_print(
            f"Train: {len(self.dataset['train'])}, Unseen Task: {len(self.dataset['test']['unseen_task'])}, Unseen Subject: {len(self.dataset['test']['unseen_subject'])}, Unseen Both: {len(self.dataset['test']['unseen_both'])}.\n"
        )
        
    def train(self, device: str):
        pass
    
    def test(self, device: str):
        pass
    
    def pre_process_all_tasks(self):
        pass
    
    def load_task(self):
        pass
    
    def run_task(self):
        pass
    
    def save(self):
        pass
    
    def discard_nan(self):
        pass
        
    def log_print(self, message):
        print(message)
        self.logger.info(message) 
        
def load_training_session(): pass
        
training_config = TrainingConfigV0(
    brain_encoder_config=SimpleConvConfig(),
    data_partition={
        "gwilliams": {
            "testing_subjects": [19, 20, 21],
            "testing_tasks": [0],
        },
        # "schoffelen": {
        #     "testing_subjects": [],
        #     "testing_tasks": [8, 9],
        # },
    },
    learning_rate=3e-4,
    weight_decay=1e-4,
)      

session = TrainingSessionV0(
    training_config,
    studies=["gwilliams"], # "schoffelen"
    data_path='/home/ubuntu/brain-decoding/data',
    save_path='/home/ubuntu/brain-decoding/saves'
)

Data partitioned on studies ['gwilliams']. Recordings:
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.


SimpleConv: 
	Params: 14432128
	Conv blocks: 5
	Trans layers: 0


In [8]:
brain_sample_rate = 100
frequency_bands = {"all": (0.5, 100)}
subject, task, session = 0, 3, 0
seed = 42
max_random_shift = 1
window_size = 4
n_jobs = -1

study = StudyFactory().create_study("gwilliams", path="data/gwilliams")
pre_processor = PreProcessor(
    brain_sample_rate=brain_sample_rate,
)

In [10]:
study.subjects_list[19]

'20'

In [4]:
brain_segments, audio_segments, layout = fetch_audio_and_brain_pairs(
    subject=subject,
    task=task,
    session=session,
    max_random_shift=max_random_shift,
    window_size=window_size,
    study=study,
    pre_processor=pre_processor,
    frequency_bands=frequency_bands,
    audio_sample_rate=16000,
    hop_length=160,
    n_jobs=n_jobs,
)
brain_segments['all'].shape, audio_segments.shape, layout.shape

(torch.Size([1094, 208, 400]),
 torch.Size([1094, 128, 400]),
 torch.Size([208, 2]))

In [None]:
model = SimpleConv(SimpleConvConfig(transformer_layers=0)).to(device)


SimpleConv: 
	Params: 14432128
	Conv blocks: 5
	Trans layers: 0


In [None]:
output = model(
    {"meg": brain_segments["all"].to(device)},
    layout=layout.to(device),
    subjects=torch.full((brain_segments["all"].shape[0],), subject).to(device),
)

In [None]:
output.shape

torch.Size([199, 128, 400])