In [42]:
import os
from pathlib import Path
import yaml
import argparse
import pandas as pd
import lightning.pytorch as pl
import optuna
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping

from update_config import update_hardware_config
from objective_augment import objective
from datetime import datetime
import time
from utils.debugging import add_debugging_to_lightning_module
import torch
import numpy as np
import lightning.pytorch as pl
from torchinfo import summary # For calculating GFLOPs

from models.dino import MultiModalDINOLightning, CrossAttentionViTMultiModalEncoder, \
UniModalDINOLightning, ImageEncoder, \
SpectrogramEncoder, SpectrogramEncoderCentral, SpectrogramEncoderLSTM, SpectrogramEncoderResidual, \
SpectrogramEncoderViT, SpectrogramEncoderMobileViT, SpectrogramEncoderResNet
from utils.get_data import AVMNISTDinoDataModule, TimeWarpWithStretch, GroupedMasking, GaussianNoise
from datetime import datetime
from utils.plots_trials import create_plots_for_study, load_all_versions, process_metrics, save_versions_to_csv, create_dir, plot_loss
from utils.visualisations import pca_plot_dataloaders, pca_plot_multiclass, visualize_prediction_matrix
from utils.get_data import get_dataloader_augmented
from training_structures.dino_train import train_downstream, train_knn_classifier
from optuna.storages import JournalStorage, JournalFileStorage, RetryFailedTrialCallback, RDBStorage
import torchvision.transforms as transforms
import torchaudio.transforms as audio_transforms

In [None]:
class MultiModalAugmentation:
    def __init__(self, n_global_views=2, n_local_views=4, 
                 global_spec_size=112, local_spec_size=112, augment_values=None):
        
        self.n_local_views = n_local_views
        self.n_global_views = n_global_views
        self.global_spec_size = global_spec_size
        self.local_spec_size = local_spec_size

        self._initialize_transforms(augment_values)

    def _initialize_transforms(self, augment_values=None):
        global_image_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=28, scale=(0.75, 1.0), antialias=True),
            transforms.RandomRotation(degrees=5),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        ])
        local_image_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=28, scale=(0.3, 0.75), antialias=True),
            transforms.RandomRotation(degrees=15),
            transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.2)),
            transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
        ])
        if augment_values is None:
            self.global_transforms = {
                'image': global_image_transforms,
                'audio': transforms.Compose([
                    # Keep original size but apply moderate augmentations
                    transforms.RandomResizedCrop(size=(self.global_spec_size, self.global_spec_size), scale=(0.8, 1.0), antialias=True),
                    transforms.RandomApply([TimeWarpWithStretch(min_factor=0.9, max_factor=1.1, target_length=self.global_spec_size)], p=0.5),  # Add time-stretch (global)
                    # Moderate frequency/time masking
                    transforms.RandomApply([audio_transforms.FrequencyMasking(freq_mask_param=15)], p=0.3),
                    transforms.RandomApply([audio_transforms.TimeMasking(time_mask_param=15)], p=0.3),
                    # Add gentle pitch shifting (vertical shift in spectrogram)
                    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                    # Grouped Masking
                    transforms.RandomApply([GroupedMasking(mask_ratio=0.25)], p=0.25),
            ]),
            }

            self.local_transforms = {
                'image': local_image_transforms,
                'audio': transforms.Compose([
                    # Apply crop to create local view but maintain overall dimensions
                    transforms.RandomResizedCrop(size=(self.local_spec_size, self.local_spec_size), scale=(0.5, 0.9), antialias=True),
                    transforms.RandomApply([TimeWarpWithStretch(min_factor=0.7, max_factor=1.3, target_length=self.local_spec_size)], p=0.7),  # Add time-stretch (local)
                    # Much stronger frequency masking (vertical strips)
                    transforms.RandomApply([audio_transforms.FrequencyMasking(freq_mask_param=40)], p=0.6),
                    # Much stronger time masking (horizontal strips)
                    transforms.RandomApply([audio_transforms.TimeMasking(time_mask_param=40)], p=0.6),
                    # Multiple frequency masks
                    # transforms.RandomApply([
                    #     transforms.Compose([
                    #         audio_transforms.FrequencyMasking(freq_mask_param=25),
                    #         audio_transforms.FrequencyMasking(freq_mask_param=15)
                    #     ])
                    # ], p=0.25), # low p means it's a rare occurance that this transform is applied
                    # Multiple time masks
                    # transforms.RandomApply([
                    #     transforms.Compose([
                    #         audio_transforms.TimeMasking(time_mask_param=25),
                    #         audio_transforms.TimeMasking(time_mask_param=15)
                    #     ])
                    # ], p=0.25),
                    # Add more aggressive affine transformations (simulates pitch/time stretching in spectrogram)
                    # transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.7, 1.3), shear=10),
                    # Add noise
                    transforms.RandomApply([GaussianNoise(std=0.2)], p=0.5),
                    # Add random erasing (simulates dropout of frequency/time regions)
                    # transforms.RandomErasing(p=0.5, scale=(0.02, 0.2)),
                    # Add another random erasing for multiple masks
                    # transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
                    # Grouped masking
                    transforms.RandomApply([GroupedMasking(mask_ratio=0.75)], p=0.9),
                ]),
            }
        else:
            aug_to_class = {
                "time_warp": TimeWarpWithStretch,
                "frequency_mask": audio_transforms.FrequencyMasking,
                "time_mask": audio_transforms.TimeMasking,
                "grouped_masking": GroupedMasking,
                "gaussian_noise": GaussianNoise,
                # "random_erasing": transforms.RandomErasing,
            }

            augment_list = {"global_views": [], "local_views": []}
            for aug_type in ["global_views", "local_views"]:
                augmentations = augment_values['augmentations'][aug_type]
                augmentation_probabilities = augment_values['augmentation_probabilities'][aug_type]
                for aug in augmentations.keys():
                    print("Making a transform:")
                    print(f"Final transform: transforms.RandomApply({aug_to_class[aug]}({augmentations[aug]}), p={augmentation_probabilities[aug]})")
                    augment_list[aug_type].append(transforms.RandomApply([aug_to_class[aug](**augmentations[aug])], p=augmentation_probabilities[aug]))

            self.global_transforms = {
                'image': global_image_transforms,
                'audio': transforms.Compose(augment_list['global_views']),
            }

            self.local_transforms = {
                'image': local_image_transforms,
                'audio': transforms.Compose(augment_list['local_views']),
            }

    @torch.no_grad()
    def __call__(self, images, audios):
        global_images = []
        global_audios = []

        for _ in range(self.n_global_views):
            # Process global views
            global_images.append(self.global_transforms['image'](images))
            global_audios.append(self.global_transforms['audio'](audios))
        
        global_images = torch.stack(global_images, dim=0)
        global_audios = torch.stack(global_audios, dim=0)

        # Generate multiple diverse local views
        local_images = []
        local_audios = []
        
        for _ in range(self.n_local_views):
            local_images.append(self.local_transforms['image'](images))
            local_audios.append(self.local_transforms['audio'](audios))
            
        local_images = torch.stack(local_images, dim=0)
        local_audios = torch.stack(local_audios, dim=0)
        
        return global_images, global_audios, local_images, local_audios

In [58]:
config_path = os.path.join(Path.cwd(), 'configs', 'config_augments.yaml')
encoder_class = SpectrogramEncoderMobileViT

config = yaml.safe_load(open(config_path))
config = update_hardware_config(config)

pl.seed_everything(config['experiment']['seed'], workers=True)

# Model remains static throughout the search/experiment, however augments change dynamically based on config
model = UniModalDINOLightning(
        encoder_class = encoder_class,
        data_dir="C:/Users/Ward/Desktop/vub-github/Thesis-project/AVMNIST_Experiments/data/avmnist/", 
        dropout=config['hyperparameters']['dropout'],
        learning_rate = config['hyperparameters']['learning_rate'],
        projection_dim = config['hyperparameters']['projection_dim'],
        output_dim = config['hyperparameters']['output_dim'],
        momentum = config['hyperparameters']['momentum'],
        center_momentum = config['hyperparameters']['center_momentum'],
        teacher_temperature=config['hyperparameters']['teacher_temperature'],
        weight_decay=config['hyperparameters']['weight_decay'],
        cosine_loss_alpha=config['hyperparameters']['cosine_loss_alpha'],
        num_epochs=config['hyperparameters']['num_epochs'],
        data_augmentation=config['hyperparameters'].get('data_augmentation', 'burst_noise')
    )

def _process_augment_config(trial, config, is_hyperparameter_search=False):
    """
    Process augmentation configuration for either:
    - Hyperparameter search (using Optuna trial)
    - Final training (using best_augments from config)
    
    Args:
        trial: Optuna trial object (None for final training)
        config: Full configuration dictionary
        is_hyperparameter_search: Boolean flag for mode selection
    
    Returns:
        Dictionary with:
        - 'augmentations': Dict without probability parameters
        - 'augmentation_probabilities': Dict with only probabilities
    """
    if is_hyperparameter_search:
        # Hyperparameter search mode - sample from ranges
        augmentations = {'global_views': {}, 'local_views': {}}
        augmentation_probabilities = {'global_views': {}, 'local_views': {}}

        for view_setting in ['global_views', 'local_views']:
            view_params = config["optuna"]["augmentations"][view_setting].items()
            for aug, params in view_params:
                aug_params = {}
                aug_prob = None
                
                for param_name, param_info in params.items():
                    if param_name == "p":
                        aug_prob = param_info["low"]
                        augmentation_probabilities[view_setting][aug] = aug_prob
                    else:
                        if param_info["type"] == "uniform":
                            aug_params[param_name] = param_info["low"]
                        elif param_info["type"] == "int":
                            aug_params[param_name] = param_info["low"]
                
                if aug_params:  # Only add if there are non-p parameters
                    augmentations[view_setting][aug] = aug_params

        return {
            "augmentations": augmentations,
            "augmentation_probabilities": augmentation_probabilities
        }
    else:
        # Final training mode - use fixed best_augments
        if "best_augments" not in config:
            raise ValueError("best_augments not found in config for final training")
            
        # Create structures without 'p' parameters
        augmentations = {
            'global_views': {},
            'local_views': {}
        }
        augmentation_probabilities = {
            'global_views': {},
            'local_views': {}
        }
        
        for view_setting in ['global_views', 'local_views']:
            for aug, params in config["best_augments"][view_setting].items():
                # Filter out 'p' from augmentations
                aug_params = {k: v for k, v in params.items() if k != 'p'}
                if aug_params:  # Only add if there are non-p parameters
                    augmentations[view_setting][aug] = aug_params
                
                # Store probability separately
                if 'p' in params:
                    augmentation_probabilities[view_setting][aug] = params['p']
        
        return {
            "augmentations": augmentations,
            "augmentation_probabilities": augmentation_probabilities
        }

augment_values = _process_augment_config(None, config, is_hyperparameter_search=True)
multimodal_augments = MultiModalAugmentation(augment_values=augment_values)
print(augment_values)
print(multimodal_augments)

Seed set to 0


Updated config: {'device': 'cpu', 'num_gpus': 0, 'num_workers': 4}
Making a transform:
Final transform: transforms.RandomApply(<class 'utils.get_data.TimeWarpWithStretch'>({'min_factor': 0.9, 'max_factor': 1.1}), p=0.3)
Making a transform:
Final transform: transforms.RandomApply(<class 'torchaudio.transforms._transforms.FrequencyMasking'>({'freq_mask_param': 5}), p=0.3)
Making a transform:
Final transform: transforms.RandomApply(<class 'torchaudio.transforms._transforms.TimeMasking'>({'time_mask_param': 5}), p=0.3)
Making a transform:
Final transform: transforms.RandomApply(<class 'utils.get_data.GaussianNoise'>({'std': 0.01}), p=0.1)
Making a transform:
Final transform: transforms.RandomApply(<class 'utils.get_data.TimeWarpWithStretch'>({'min_factor': 0.5, 'max_factor': 1.3}), p=0.8)
Making a transform:
Final transform: transforms.RandomApply(<class 'torchaudio.transforms._transforms.FrequencyMasking'>({'freq_mask_param': 20}), p=0.8)
Making a transform:
Final transform: transforms.Ra

