In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../")
from utils.reproducibility import set_seed
set_seed(1)
import os
import torch
import numpy as np
import random
from models.unimodal import CentralUnimodalImage, CentralUnimodalAudio, UnimodalImage, UnimodalAudio
import torch.nn as nn
import torch.optim as optim
from utils.get_data import get_dataloader_augmented, load_results_from_csv, AVMNISTDataModule, AVMNISTDinoDataModule, AVMNISTDataset
from training_structures.unimodal import train as unimodal_train, test as unimodal_test
from models.centralnet.centralnet import SimpleAV_CentralNet as CentralNet
from training_structures.centralnet_train import train_centralnet, test_centralnet
import matplotlib.pyplot as plt
from utils.visualisations import show_images, show_images_augmentations, \
evaluate_results, plot_training_results_from_csv, plot_training_results_from_csvs, \
pca_plot_multiclass, tsne_plot_multiclass, pca_plot_dataloaders
from torchvision import transforms, datasets
from torchmetrics.classification import Accuracy
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import lightning.pytorch as pl
from lightning.pytorch.strategies import DDPStrategy
import torch.multiprocessing
# torch.multiprocessing.set_start_method('spawn')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Paths are resolved relative to the notebook's parent directory
# (assumes the notebook is launched from a subfolder of the project root).
current_path = os.getcwd()
sys.path.append(current_path)
parent_dir = os.path.dirname(current_path)
data_dir = os.path.join(parent_dir, "data/avmnist/")
dir_logs = os.path.join(parent_dir, "supervised_results/")

  from .autonotebook import tqdm as notebook_tqdm


#### Initialize Dataloaders

In [2]:
# One AV-MNIST data module shared by every experiment in this notebook.
mnist_data = AVMNISTDataModule(
    data_dir=data_dir,
    batch_size=128,
    num_workers=0,  # presumably forwarded to DataLoader (0 = load in the main process)
)
mnist_data.setup()

In [3]:
# Plain loaders for the non-Lightning train/test helpers used further down.
traindata = mnist_data.train_dataloader()
validdata = mnist_data.val_dataloader()
testdata = mnist_data.test_dataloader()

#### Hyperparameter settings

In [4]:
class Args_Unimodal:
    """Hyperparameters for the unimodal (image-only / audio-only) baselines.

    Every argument is optional; the defaults reproduce the original fixed values,
    so ``Args_Unimodal()`` behaves exactly as before.

    Args:
        learning_rate: Initial learning rate.
        batch_size: Batch size. NOTE: the AVMNISTDataModule is constructed with its
            own ``batch_size``; keep the two in sync.
        epochs: Total training epochs.
    """

    def __init__(self, learning_rate=0.001, batch_size=128, epochs=1):
        self.criterion = nn.CrossEntropyLoss()  # loss function
        self.use_cuda = torch.cuda.is_available()  # use GPU if available
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs

In [5]:
class Args_CentralNet:
    """Hyperparameters for the multimodal CentralNet model.

    Every argument is optional; the defaults reproduce the original fixed values,
    so ``Args_CentralNet()`` behaves exactly as before.

    Args:
        channels: Base convolution channels.
        fusingmix: Fusion-mix specification string passed to CentralNet
            (its format is defined by the CentralNet implementation).
        fusetype: Fusion type; ``'wsum'`` is weighted-sum fusion.
        num_outputs: Number of output classes (10 digits for AV-MNIST).
        learning_rate: Initial learning rate.
        batch_size: Batch size. NOTE: the AVMNISTDataModule is constructed with its
            own ``batch_size``; keep the two in sync.
        epochs: Total training epochs.
    """

    def __init__(self, channels=16, fusingmix='11,32,53', fusetype='wsum',
                 num_outputs=10, learning_rate=0.001, batch_size=128, epochs=1):
        self.channels = channels
        self.fusingmix = fusingmix
        self.fusetype = fusetype
        self.num_outputs = num_outputs
        self.criterion = nn.CrossEntropyLoss()  # loss function
        self.use_cuda = torch.cuda.is_available()  # use GPU if available
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs

#### Unimodal Image

In [None]:
# Train and evaluate the image-only baseline once per seed, then summarise test accuracy.
seeds = [1, 2, 3]
test_accuracies = []
args_image = Args_Unimodal()

# Pick the Lightning accelerator the same way `device` is chosen, so the cell
# also runs on CPU-only machines instead of failing on accelerator="gpu".
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

for seed in seeds:
    print(f"Running seed: {seed}")
    set_seed(seed)

    # Separate directory per seed so runs do not overwrite each other.
    run_dir = f"{dir_logs}unimodal_image/seed_{seed}/"
    os.makedirs(run_dir, exist_ok=True)

    # No `monitor`: keeps the most recent checkpoint; save_last=True also writes last.ckpt.
    checkpoint_callback = ModelCheckpoint(
        dirpath=run_dir,
        filename='unimodal_image-last',
        save_last=True
    )

    logger = CSVLogger(dir_logs, name=f"unimodal_image/seed_{seed}")

    image_model = UnimodalImage(with_head=True, num_epochs=args_image.epochs)

    trainer = pl.Trainer(
        max_epochs=args_image.epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        accelerator=accelerator,
        deterministic=True,
        # strategy="ddp" if multi-GPU
    )
    trainer.fit(image_model, mnist_data)

    # After the loop, `trained_model` holds the model from the LAST seed only.
    trained_model = image_model.model

    # Evaluate on the image modality (modalnum=0).
    modalnum = 0
    test_log_file = os.path.join(run_dir, "test_log.csv")

    avg_loss, accuracy, all_labels, all_probs = unimodal_test(
        trained_model,
        testdata,
        args_image.criterion,
        device,
        modalnum=modalnum,
        test_log_file=test_log_file
    )

    test_accuracies.append(accuracy)

# Final summary (np.std is the population std, ddof=0).
mean_acc = np.mean(test_accuracies)
std_acc = np.std(test_accuracies)

print(f"Mean Test Accuracy over {len(seeds)} seeds: {mean_acc:.4f}")
print(f"Std Dev of Accuracy: {std_acc:.4f}")

In [None]:
pca_plot_path = dir_logs + "unimodal_image/plots/"  # output folder for the image-model embedding plots

class ImageWrapper(nn.Module):
    """Adapt an image-only model to the ``(image, audio)`` call signature.

    The same plotting helpers are also used with ``FusionOnlyWrapper``, whose
    forward takes both modalities. This wrapper accepts the same two arguments,
    forwards only ``image`` to the wrapped unimodal model, and ignores ``audio``.
    """
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, image, audio=None):
        # `audio` exists only for signature compatibility and is intentionally unused.
        return self.base_model(image)

# NOTE: `trained_model` is the model from the last seed of the multi-seed loop.
image_wrapper = ImageWrapper(trained_model)

# Pairwise PCA on two easily-confused digits, then PCA and t-SNE over all ten classes.
_ = pca_plot_dataloaders(
    image_wrapper, testdata,
    selected_digits=[5, 8],
    dirpath=pca_plot_path, show_plots=False, is_dino_based=False,
)
_ = pca_plot_multiclass(
    image_wrapper, testdata,
    selected_digits=list(range(10)),
    dirpath=pca_plot_path, show_plots=False, is_dino_based=False,
)
_ = tsne_plot_multiclass(
    image_wrapper, testdata,
    selected_digits=list(range(10)),
    dirpath=pca_plot_path, show_plots=False, random_seed=0, is_dino_based=False,
)

In [None]:
# Single long (100-epoch) image run; checkpoints and CSV logs go under <dir_logs>/unimodal_image/.
image_run_name = 'unimodal_image'
checkpoint_callback = ModelCheckpoint(
    save_last=True,  # also write last.ckpt
    filename=f'{image_run_name}-last',
    dirpath=f'{dir_logs}{image_run_name}/',
)
logger = CSVLogger(dir_logs, name=image_run_name)
image_model = UnimodalImage(with_head=True, num_epochs=100)

In [None]:
# ddp = DDPStrategy()
trainer = pl.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu",
    # strategy= ddp  # Add this line to enable data parallelism
    )
trainer.fit(image_model, mnist_data)

In [None]:
# Evaluate the 100-epoch image model on the image modality (modalnum=0).
args_image = Args_Unimodal()
model = image_model.model
modalnum = 0
test_log_file = f'{dir_logs}unimodal_image/test_log.csv'
_ = unimodal_test(model, testdata, args_image.criterion, device,
                  modalnum=modalnum, test_log_file=test_log_file)

#### Unimodal Audio

In [None]:
# Keep only the single best checkpoint, ranked by lowest 'val_loss'
# (assumes UnimodalAudio logs a metric named 'val_loss').
# NOTE(review): writes under the relative 'centralnet_results/' rather than dir_logs like the image runs.
audio_results_root = 'centralnet_results/'
checkpoint_callback = ModelCheckpoint(
    dirpath=f'{audio_results_root}unimodal_audio/',
    filename='unimodal_audio-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)
logger = CSVLogger(audio_results_root, name='unimodal_audio')
audio_model = UnimodalAudio(with_head=True)

In [None]:
# ddp = DDPStrategy()
trainer = pl.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu",
    # strategy= ddp  # Add this line to enable data parallelism
    )
trainer.fit(audio_model, mnist_data)

In [None]:
# Evaluate the audio model on the audio modality, logging to its own file.
# Set modalnum and the log path explicitly instead of inheriting the image cell's
# values (modalnum=0 and unimodal_image/test_log.csv).
args_audio = Args_Unimodal()
model = audio_model.model
# Audio is assumed to be the second modality in each batch (image is modalnum=0);
# TODO confirm against AVMNISTDataset's item order.
modalnum = 1
audio_test_dir = f'{dir_logs}unimodal_audio/'
os.makedirs(audio_test_dir, exist_ok=True)
test_log_file = os.path.join(audio_test_dir, 'test_log.csv')
_ = unimodal_test(model, testdata, args_audio.criterion, device,
                  modalnum=modalnum, test_log_file=test_log_file)

#### Multimodal CentralNet

In [7]:
args_central = Args_CentralNet()

# Augmentation variants to sweep; uncomment entries to include them.
aug_types = [
    # "aliased",
    "burst_noise",
    # "distorted",
    # "extreme_noise",
    # "multi_band",
]

dir_train_logs = "training_logs/central/"
dir_test_logs = "test_logs/central/"
os.makedirs(dir_train_logs, exist_ok=True)
os.makedirs(dir_test_logs, exist_ok=True)

# (No `if __name__ == "__main__"` guard: in a notebook it is always true.)
for aug_type in aug_types:
    # NOTE(review): traindata/validdata/testdata come from the plain AVMNISTDataModule,
    # so `aug_type` currently only names the output files; no augmentation is applied.
    # Wire in get_dataloader_augmented (imported above) if augmented training is intended.
    print(f"Training with augmentation type: {aug_type}")
    model = CentralNet(args_central, audio_channels=1, image_channels=1).to(device)  # grayscale input for both modalities
    log_file = f"{dir_train_logs}training_log_central_{aug_type}.csv"

    # The return value is passed to torch.load below, so presumably it is the saved model's path.
    model_name = train_centralnet(model, args_central, traindata, device, val_loader=validdata,
                                  log_file=log_file,
                                  save_model=f'model_central_augmented_{aug_type}.pt')

    test_log_file = f"{dir_test_logs}test_results_central_{aug_type}.csv"

    print(f"Testing with augmentation type: {aug_type}")
    # The file holds a fully pickled nn.Module (it is called as a model afterwards), not a
    # state_dict, so weights_only=False is required on newer torch. Only load files this
    # notebook wrote itself: unpickling can execute arbitrary code.
    model = torch.load(model_name, map_location=device, weights_only=False)
    test_loss, test_accuracy, all_labels, all_probs = test_centralnet(
        model, testdata, args_central.criterion, device, test_log_file=test_log_file
    )

Training with augmentation type: burst_noise
Epoch 1/1, Loss: 3.4985
Validation Loss: 0.3932, Accuracy: 87.78%
Saving Best
Training Complete!
Testing with augmentation type: burst_noise


  model = torch.load(model_name)


Test Loss: 0.4429, Test Accuracy: 85.76%


In [None]:
class FusionOnlyWrapper(nn.Module):
    """Expose only the fused output of a three-headed CentralNet-style model.

    ``forward(image, audio)`` calls ``base_model(audio, image)``, which must return
    exactly three values, and returns the last of them (the fusion output).
    The wrapper can then be handed to the embedding-plot helpers.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, image, audio):
        # Arguments are swapped on purpose; the wrapped model is assumed to take
        # audio first, then image (confirm against SimpleAV_CentralNet.forward).
        outputs = self.base_model(audio, image)
        _, _, fused = outputs
        return fused

# `model` is the CentralNet loaded in the last iteration of the augmentation loop.
central_wrapper = FusionOnlyWrapper(model)
pca_plot_path = "centralnet_results/plots/"

# Pairwise PCA on digits 5 vs 8, then PCA and t-SNE over all ten classes.
_ = pca_plot_dataloaders(
    central_wrapper, testdata,
    selected_digits=[5, 8],
    dirpath=pca_plot_path, show_plots=False, is_dino_based=False,
)
_ = pca_plot_multiclass(
    central_wrapper, testdata,
    selected_digits=list(range(10)),
    dirpath=pca_plot_path, show_plots=False, is_dino_based=False,
)
_ = tsne_plot_multiclass(
    central_wrapper, testdata,
    selected_digits=list(range(10)),
    dirpath=pca_plot_path, show_plots=False, random_seed=0, is_dino_based=False,
)
    