In [1]:
# Notebook holding the code of a blank transformer model.
import sys

# Put the MusicGen project directory on the module search path so that
# modules stored there can be imported by the cells below.
sys.path.extend(['/workspace/fourth_year_project/MusicGen'])

In [2]:
from MyAudioDataset import MyAudioDataset

In [3]:
import torch
from torch import nn


In [4]:

class AudioTransformer(nn.Module):
    """Transformer that maps mono audio to two-channel (stereo) audio.

    Every audio sample is treated as one token: it is projected from a single
    value to ``d_model`` features, passed through an ``nn.Transformer``, and
    projected back to two values (one per output channel).

    Args:
        d_model: Feature size used inside the transformer.
        nhead: Number of attention heads.
        num_layers: Number of layers in both the encoder and the decoder stack.
        dim_feedforward: Hidden size of the transformer feed-forward sublayers.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Keyword arguments are required for correctness here: the positional
        # order of nn.Transformer is (d_model, nhead, num_encoder_layers,
        # num_decoder_layers, dim_feedforward), so passing dim_feedforward in
        # fourth position would configure the decoder depth instead.
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
        )
        self.input_encoding = nn.Linear(1, d_model)  # input of mono audio
        # TODO: add an angle encoding (nn.Linear(6, d_model)) and sum it with
        # the audio features once the single-angle ("90") case works.
        self.output_decoding = nn.Linear(d_model, 2)  # decoding back to stereo audio

    def forward(self, audio, angle):
        """Convert a batch of mono audio into stereo audio.

        Args:
            audio: Float tensor of shape (batch_size, 1, seq_length).
            angle: One-hot encoded angle, (batch_size, 6). Currently unused;
                accepted so callers do not change when angle conditioning is
                added.

        Returns:
            Tensor of shape (batch_size, 2, seq_length).
        """
        # nn.Linear acts on the last dimension, so move the single audio
        # channel last. The result is already in the (seq_length, batch_size,
        # features) layout that nn.Transformer expects by default.
        tokens = self.input_encoding(audio.permute(2, 0, 1))

        # nn.Transformer.forward requires both a source and a target sequence.
        # NOTE(review): no separate target sequence exists for this mapping,
        # so the embedded input is used for both - confirm this is intended.
        hidden = self.transformer(tokens, tokens)

        stereo = self.output_decoding(hidden)  # (seq_length, batch_size, 2)
        # Return channels-first audio; the sequence length follows the input
        # instead of being fixed to 30 seconds at 44.1 kHz.
        return stereo.permute(1, 2, 0)

    def train(self, dataset=True, batch_size=None, epochs=None, lr=None):
        """Run the training loop, or toggle training mode when given a bool.

        Defining ``train`` on an ``nn.Module`` shadows ``nn.Module.train(mode)``,
        which PyTorch itself calls from ``eval()``. To keep ``model.eval()`` and
        ``model.train()`` working, a boolean first argument is forwarded to the
        base implementation.

        Args:
            dataset: Dataset yielding ``(audio, angle)`` pairs (e.g. a
                MyAudioDataset object), or a bool selecting training/eval mode.
            batch_size: Number of samples per batch (int).
            epochs: Number of passes over the dataset (int).
            lr: Learning rate for the Adam optimizer (float).

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.

        Raises:
            TypeError: If a dataset is given without batch_size, epochs and lr.
        """
        if isinstance(dataset, bool):
            return super().train(dataset)
        if batch_size is None or epochs is None or lr is None:
            raise TypeError("train(dataset, batch_size, epochs, lr) requires all four arguments")

        # The loss must be instantiated; calling the class itself would build a
        # loss module out of the tensors instead of computing a loss value.
        loss_fn = nn.CosineEmbeddingLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            for i, (audio, angle) in enumerate(train_loader):
                optimizer.zero_grad()
                output = self(audio, angle)

                # NOTE(review): the dataset yields no separate stereo target,
                # so the input audio is the reference (as in the original
                # code), broadcast across the output channels - confirm.
                reference = audio.expand_as(output).flatten(start_dim=1)
                predicted = output.flatten(start_dim=1)
                # CosineEmbeddingLoss needs a +1/-1 label per sample; +1 asks
                # for the prediction to be similar to the reference.
                similar = torch.ones(predicted.size(0), device=predicted.device)

                loss = loss_fn(predicted, reference, similar)
                loss.backward()
                optimizer.step()
                print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")

            torch.save(self.state_dict(), f"model_{epoch}.pth")
            print(f"Saved model_{epoch}.pth")

        print("Finished Training")