In [1]:
import sys
sys.path.append('/workspace/fourth_year_project/HRTF Models/')

from HRIRDataset import HRIRDataset
from MainModel import MainModel
from AutoregressiveModel import AutoregressiveModel
import matplotlib.pyplot as plt
from MaskModel import MaskModel
from SeqModel import SeqModel
# Path to subject 001's SOFA file; the other subjects' paths are derived from
# it by swapping the three-digit subject id.
sofa_file = '/workspace/fourth_year_project/HRTF Models/sofa_hrtfs/RIEC_hrir_subject_001.sofa'
hrir_dataset = HRIRDataset()
# Subjects 001 through 099 are all loaded through the same dataset object.
for subject_id in range(1, 100):
    subject_tag = f'{subject_id:03d}'
    hrir_dataset.load(sofa_file.replace('001', subject_tag))
len(hrir_dataset)
import os

import torch
from torch.utils.data import DataLoader

# --- Model -----------------------------------------------------------------
model = MaskModel()
# Set the model to training mode
model.train()
num_epochs = 100

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# --- Train / validation / test split ----------------------------------------
# 70 % / 20 % / 10 % respectively; the test set takes whatever remains after
# the integer truncation of the first two sizes so the three always sum to the
# dataset length.
dataset_size = len(hrir_dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.2 * dataset_size)
test_size = dataset_size - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    hrir_dataset, [train_size, val_size, test_size]
)

# To resume from a checkpoint, uncomment:
# model.load_state_dict(torch.load('/workspace/fourth_year_project/HRTF Models/mask_models/model_4.pth'))
batch_size = 32

# --- Data loaders -----------------------------------------------------------
# Only the training loader is shuffled; evaluation order does not need to be
# randomised, and a fixed order keeps the batch composition stable per epoch.
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

# --- Checkpoint directory ---------------------------------------------------
target_folder = '/workspace/fourth_year_project/HRTF Models/mask_models/'
# exist_ok=True creates the folder if missing and is a no-op otherwise,
# without the check-then-create race of a separate os.path.exists() test.
os.makedirs(target_folder, exist_ok=True)

def mask_values(tensor, mask_value, mask_prob):
    """
    Replace elements of `tensor` with `mask_value`, each independently with
    probability `mask_prob`.

    The input tensor is not modified; a new tensor is produced for the masked
    result.

    Args:
    tensor (torch.Tensor): The input tensor (floating point).
    mask_value (float): The value written into every masked position.
    mask_prob (float): The probability, in [0, 1], of masking each element.

    Returns:
    torch.Tensor: The masked tensor (same shape and dtype as `tensor`).
    torch.Tensor: The original, unmasked tensor (the same object passed in).
    """
    # One Bernoulli(mask_prob) draw per element; True marks a masked position.
    mask = torch.bernoulli(torch.full_like(tensor, mask_prob)).bool()

    # Overwrite masked positions directly. An arithmetic blend such as
    # `tensor * (1 - mask) + mask * mask_value` would turn inf/NaN inputs into
    # NaN at masked positions (inf * 0), so the sentinel would be lost there.
    masked_tensor = tensor.masked_fill(mask, mask_value)

    return masked_tensor, tensor

import torch
from torch import optim, nn
from torch.optim.lr_scheduler import StepLR

# Adam with its default hyper-parameters (no explicit learning rate is passed).
optimizer = optim.Adam(model.parameters())
# reduction='none' keeps the per-element squared errors so the training loop
# can apply per-element weights before averaging.
loss_function = nn.MSELoss(reduction='none')
# Multiply the learning rate by 0.1 after every 30 scheduler steps.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# Set the model to training mode
model.train()

# Initial probability of masking each input element.
percent_masked = 0.1
# NOTE(review): `factor` is not referenced in the visible cells; kept in case
# a later cell reads it.
factor = 0.003

# Best (lowest) summed validation loss seen so far. Start at +inf so the first
# epoch always counts as an improvement, however large its loss is; a finite
# sentinel would silently miss improvements above it.
best_val_loss = float('inf')



In [2]:
# Sentinel written into masked input positions. Masked positions are recovered
# afterwards by comparing against this value, so real samples are assumed never
# to equal it exactly.
MASK_VALUE = -2
# The per-element loss weight is 1 + extra on masked positions, 1 elsewhere.
TRAIN_MASKED_EXTRA_WEIGHT = 3  # masked elements count 4x during training
VAL_MASKED_EXTRA_WEIGHT = 9    # masked elements count 10x during validation
# NOTE(review): because the two weights differ, the reported training and
# validation losses are on different scales and are not directly comparable.
# Confirm this is intentional.


def _masked_reconstruction_loss(batch, mask_prob, masked_extra_weight):
    """
    Mask one batch, run the model on it and return the weighted mean MSE.

    Args:
    batch: A `(src, _, angle)` triple as yielded by the data loaders; the
        middle element is ignored.
    mask_prob (float): Probability of masking each element of `src`.
    masked_extra_weight (float): Extra loss weight added on masked elements
        (so masked elements weigh `1 + masked_extra_weight`).

    Returns:
    torch.Tensor: Scalar loss tensor (still attached to the graph when
    gradients are enabled).
    """
    src, _, angle = batch

    # Move the raw batch to the model's device once; every tensor derived from
    # it below is then created on that device with no further copies.
    src = src.to(device)
    angle = angle.to(device).float()

    masked_src, true_values = mask_values(src, MASK_VALUE, mask_prob)
    mask = (masked_src == MASK_VALUE).float()
    weights = mask * masked_extra_weight + torch.ones_like(masked_src)

    # convert to floats
    masked_src = masked_src.float()
    true_values = true_values.float()

    # NOTE(review): the unmasked targets are passed to the model alongside the
    # masked input; confirm the model cannot read the masked positions from them.
    output = model(masked_src, angle, true_values)

    # Drop the last entry along dim 1 so the output lines up with the targets
    # (per the original note: [batch_size, d_model, seq_length] -->
    # [batch_size, d_model-1, seq_length]).
    output = output[:, :-1, :]

    weighted = loss_function(output, true_values) * weights
    return weighted.mean()


# Epochs are numbered from 1; the upper bound is num_epochs + 1 so that exactly
# `num_epochs` epochs run (range(1, num_epochs) would stop one epoch short).
for epoch in range(1, num_epochs + 1):
    # ---- Training ----
    epoch_loss = 0.0
    model.train()
    for batch in dataloader:
        # Zero the gradients
        optimizer.zero_grad()

        loss = _masked_reconstruction_loss(batch, percent_masked, TRAIN_MASKED_EXTRA_WEIGHT)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss
        epoch_loss += loss.item()

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # steps of that epoch.
    scheduler.step()

    # ---- Validation ----
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            val_loss += _masked_reconstruction_loss(
                batch, percent_masked, VAL_MASKED_EXTRA_WEIGHT
            ).item()

    # Curriculum: whenever the summed validation loss improves, make the task
    # slightly harder by raising the masking probability.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        percent_masked += 0.001  # +0.1 percentage points

    # Print the average loss for this epoch
    print(f'Epoch {epoch} | Training Loss: {epoch_loss / len(dataloader)} | Validation Loss: {val_loss / len(val_loader)} | Learning Rate: {scheduler.get_last_lr()} | Percentage Masked: {percent_masked} | Elements Masked: {int((percent_masked) * 512)}')

    # Checkpoint every second epoch
    if epoch % 2 == 0:
        torch.save(model.state_dict(), f'{target_folder}model_{epoch}.pth')

Epoch 1 | Training Loss: 0.4561279947657671 | Validation Loss: 0.4146750320053831 | Learning Rate: [0.001] | Percentage Masked: 0.101 | Elements Masked: 51
Epoch 2 | Training Loss: 0.1831001012937679 | Validation Loss: 0.15277886579105604 | Learning Rate: [0.001] | Percentage Masked: 0.10200000000000001 | Elements Masked: 52
Epoch 3 | Training Loss: 0.06362177504427685 | Validation Loss: 0.04996196721990957 | Learning Rate: [0.001] | Percentage Masked: 0.10300000000000001 | Elements Masked: 52
Epoch 4 | Training Loss: 0.021274773133446594 | Validation Loss: 0.01871472668576531 | Learning Rate: [0.001] | Percentage Masked: 0.10400000000000001 | Elements Masked: 53
Epoch 5 | Training Loss: 0.009784003837363947 | Validation Loss: 0.011501118172628168 | Learning Rate: [0.001] | Percentage Masked: 0.10500000000000001 | Elements Masked: 53
Epoch 6 | Training Loss: 0.00745084709873713 | Validation Loss: 0.010405823747309221 | Learning Rate: [0.001] | Percentage Masked: 0.10600000000000001 | E