#### Code for running a encoder in a supervised way, this way we can figure out if it's capable of learning before we use it for DINO purposes.

In [23]:
%load_ext autoreload
%autoreload 2
import torch
torch.manual_seed(0) # Set seed before importing other modules
import numpy as np
np.random.seed(0)
import random
random.seed(0)
import sys
import os
import torch.nn as nn
from utils.get_data import get_dataloader_augmented
from training_structures.unimodal import train as unimodal_train, test as unimodal_test
import torch.multiprocessing
from models.dino import SpectrogramEncoderMobileViT
# torch.multiprocessing.set_start_method('spawn')

current_path = os.getcwd()
sys.path.append(current_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
class Args_Unimodal:
    def __init__(self):
        self.criterion = nn.CrossEntropyLoss() # Loss function
        self.use_cuda = torch.cuda.is_available()  # Use GPU if available
        self.learning_rate = 0.001 # Initial learning rate
        self.batch_size = 128       # Batch size
        self.epochs = 10          # Total training epochs

In [25]:
class SpectrogramModel(nn.Module):
    def __init__(self, output_dim=256, num_classes = 10):
        super().__init__()
        self.encoder = SpectrogramEncoderMobileViT(output_dim=output_dim)
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.output_dim, 128),
            nn.ReLU(),
            nn.Dropout(p=0.3),                      # Dropout after ReLU
            nn.Linear(128, num_classes)
        )
    def forward(self, spectrograms=None):
        # Only use spectrograms, ignore images
        if spectrograms is None:
            raise ValueError("SpectrogramEncoder requires spectrogram input")
        features = self.encoder(images=None, spectrograms=spectrograms)
        return self.classifier(features)

In [26]:
model = SpectrogramModel().to(device)
modalnum = 1 # For audio encoder based model



In [27]:
args_audio = Args_Unimodal()
dir_train_logs = "training_logs/audio/"
dir_test_logs = "test_logs/audio/"
# if directory doesn't exist, make it:
for path in [dir_train_logs, dir_test_logs]:
    os.makedirs(os.path.dirname(path), exist_ok=True)

for aug_type in [
    # "aliased", 
    "burst_noise", 
    # "distorted", 
    # "extreme_noise", 
    # "multi_band"
    ]:
        
    model_name = f'model_audio_augmented_{aug_type}_{model.encoder.__class__.__name__}.pt'
    traindata,validdata,testdata = get_dataloader_augmented(f'{current_path}/data/avmnist', type=aug_type, batch_size=args_audio.batch_size)

    log_file = f"{dir_train_logs}training_log_audio_{aug_type}.csv"
    test_log_file = f"{dir_test_logs}test_results_audio_{aug_type}.csv"

    print(f"Training with augmentation type: {aug_type}")
    
    model_name = unimodal_train(model, args_audio, traindata, device, modalnum=modalnum, val_loader=validdata, 
                    log_file=log_file, save_model=model_name)

    print(f"Testing with augmentation type: {aug_type}")

    model= torch.load(model_name)
    _ = unimodal_test(model, testdata, args_audio.criterion, device, 
                      modalnum=modalnum, test_log_file=test_log_file)

Training with augmentation type: burst_noise
Epoch 1/10, Loss: 1.1983
Validation Loss: 3.2502, Accuracy: 10.40%
Saving Best
Epoch 2/10, Loss: 0.4739
Validation Loss: 0.5017, Accuracy: 84.16%
Saving Best
Epoch 3/10, Loss: 0.3312
Validation Loss: 0.3533, Accuracy: 88.80%
Saving Best
Epoch 4/10, Loss: 0.2611
Validation Loss: 0.3746, Accuracy: 88.88%
Saving Best
Epoch 5/10, Loss: 0.2073
Validation Loss: 0.3554, Accuracy: 88.80%
Epoch 6/10, Loss: 0.1774
Validation Loss: 0.3933, Accuracy: 89.38%
Saving Best
Epoch 7/10, Loss: 0.1528
Validation Loss: 0.3295, Accuracy: 90.76%
Saving Best
Epoch 8/10, Loss: 0.1303
Validation Loss: 0.4048, Accuracy: 89.68%
Epoch 9/10, Loss: 0.1151
Validation Loss: 0.2951, Accuracy: 91.70%
Saving Best
Epoch 10/10, Loss: 0.0973
Validation Loss: 0.3152, Accuracy: 91.34%
Training Complete!
Testing with augmentation type: burst_noise


  model= torch.load(model_name)


Test Loss: 0.4152, Test Accuracy: 88.83%
