In [1]:
# To hold the code of a blank transformer model
import sys
sys.path.append('/workspace/fourth_year_project/MusicGen')
#print(sys.path)

In [2]:
from MyAudioDataset import MyAudioDataset
from AudioCodesDataset import AudioCodesDataset

In [3]:
mydataset = MyAudioDataset('/workspace/small_model_data2', 'recording_01_')

In [4]:
audio_codes_dataset = AudioCodesDataset(mydataset)

In [5]:
#audio_codes_dataset.run_compression()

In [6]:
#audio_codes_dataset.save_data('90_degree_compress_tensors.pkl')

In [7]:
len(audio_codes_dataset.data_map)

570

In [3]:
temp_empty = AudioCodesDataset()

In [4]:
temp_empty.load_data('90_degree_compress_tensors.pkl')

In [6]:
import torch
from torch import nn
import torchaudio
import torch.nn.functional as F

In [7]:

class AudioTransformer(nn.Module):
    def __init__(self, comp_model, d_model, nhead, num_layers, dim_feedforward, compute_seperate_loss=False):
        super(AudioTransformer, self).__init__()
        self.transformer = nn.Transformer(d_model, nhead, num_layers, dim_feedforward)
        self.input_encoding = nn.Linear(1500, d_model)  # input of mono audio
        #self.angle_encoding = nn.Linear(6, d_model)  # add this once 90 works. 
        self.output_decoding = nn.Linear(d_model, 1500)  # Decoding back to stereo audio
        
        self.comp_model = comp_model.cuda()
        self.compute_seperate_loss = compute_seperate_loss

    def forward(self, orig, target, angle):
        orig = orig.cuda()
        target = target.cuda()
        #print(audio.shape)
        orig = self.input_encoding(orig)
        target = self.input_encoding(target)
        # Relu to rid negatives
        orig = F.relu(orig)
        target = F.relu(target)

        #angle = self.angle_encoding(angle)  # Process one-hot encoded angle
        #angle = angle.unsqueeze(1).repeat(1, audio.size(2), 1)  # Repeat angle for each time step
        #x = audio + angle  # Combine audio and angle

        x = self.transformer(src=orig, tgt=target)
        x = F.relu(x)

        x = self.output_decoding(x)
        x = F.relu(x)

        # Scale back to integers
        x = x * 1000
        x = torch.round(x)

        return x

    def compress(self, stereo):
            stereo = stereo.cuda()
            with torch.no_grad():
                stereo, scale = self.comp_model.encode(stereo)
            return stereo


    def decompress(self, stereo):
            stereo = stereo.cuda()
            with torch.no_grad():
                stereo = self.comp_model.decode(stereo)
            return stereo

    
    def train(self, dataset, batch_size=1, epochs=1, lr=0.001, cosine_loss=False):
        if not cosine_loss:
            loss_fn = nn.MSELoss()
        else:
            loss_fn = nn.CosineEmbeddingLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        

        for epoch in range(epochs):
            for i, (target, orig, angle, sr) in enumerate(train_loader):
                optimizer.zero_grad()
                target = target.cuda()
                orig = orig.cuda()
                # Convert wav to codes
                target_codes = self.compress(target)
                orig_codes = self.compress(orig)

                #print(type(target_codes[0][0]))
                #print(type(orig_codes[0][0]))
                # Pass codes to model
                output = self(orig=orig_codes.float(), target=target_codes.float(), angle=angle)
                #output = output.squeeze(0)
                #target = target_codes.squeeze(0)
                if cosine_loss:
                    if self.compute_seperate_loss:
                        total_loss = 0
                        for j in range(target_codes.shape[1]):
                            output_j = output[:, j, :]
                            target_codes_j = target_codes[:, j, :]
                            # Should be a 1D tensor of length equal to the batch size
                            # So if 4 then should be 4 ones
                            y = torch.ones(target_codes_j.shape[0]).cuda()
                            #print(output_j.shape, target_codes_j.shape, y.shape)
                            loss = loss_fn(output_j, target_codes_j, y)
                            total_loss += loss
                        total_loss /=  target_codes.shape[1]  # average the loss over the batch
                        total_loss.backward()  # backpropagate the average loss
                        optimizer.step()
                        scheduler.step()
                        print(f"Epoch {epoch}, batch {i}, loss {total_loss}")
                    else:
                        output = output.view(-1, 1500)
                        target_codes = target_codes.view(-1, 1500)
                        y = torch.ones((output.size(0),)).cuda()
                        loss = loss_fn(output, target_codes, y)
                        loss.backward()
                        optimizer.step()
                        scheduler.step()
                        print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")
                else:
                    # Using MSE loss
                    loss = loss_fn(output, target_codes)
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")
                    


                #y = torch.ones_like(target_codes).cuda()
                #y = torch.ones(target_codes.size(0)).cuda()
                
                
                #print(output.shape, target_codes.shape, y.shape)
                #loss = loss_fn(output, target_codes, y)
                #loss.backward()
                #optimizer.step()
                #print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")

            if epoch % 5 == 0:
                torch.save(self.state_dict(), f"model_{epoch}.pth")
                print(f"Saved model_{epoch}.pth")

        print("Finished Training")

In [9]:
from audiocraft.models import CompressionModel
from audiocraft.models.encodec import InterleaveStereoCompressionModel
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
#model = model.cuda()
comp_model = InterleaveStereoCompressionModel(model)
# move to GPU
##comp_model = comp_model.cuda()

In [10]:
myTransformer = AudioTransformer(comp_model=comp_model, d_model=512, nhead=4, num_layers=3, dim_feedforward=256).cuda()



In [11]:
myTransformer.train(dataset=mydataset, batch_size=4, epochs=10, lr=0.01)

Epoch 0, batch 0, loss 0.5078215599060059
Epoch 0, batch 1, loss 0.5051887035369873
Epoch 0, batch 2, loss 0.509951114654541
Epoch 0, batch 3, loss 0.5086799263954163
Epoch 0, batch 4, loss 0.5129799842834473
Epoch 0, batch 5, loss 0.5045744180679321
Epoch 0, batch 6, loss 0.509512186050415
Epoch 0, batch 7, loss 0.5075318217277527
Epoch 0, batch 8, loss 0.5118129849433899
Epoch 0, batch 9, loss 0.504594087600708
Epoch 0, batch 10, loss 0.5094484090805054
Epoch 0, batch 11, loss 0.5002940893173218
Epoch 0, batch 12, loss 0.5055699348449707
Epoch 0, batch 13, loss 0.5077252388000488
Epoch 0, batch 14, loss 0.5081615447998047
Epoch 0, batch 15, loss 0.5095643401145935
Epoch 0, batch 16, loss 0.49246591329574585
Epoch 0, batch 17, loss 0.49862605333328247
Epoch 0, batch 18, loss 0.504978358745575
Epoch 0, batch 19, loss 0.5092415809631348
Epoch 0, batch 20, loss 0.48742619156837463
Epoch 0, batch 21, loss 0.5113459825515747
Epoch 0, batch 22, loss 0.5109686255455017
Epoch 0, batch 23, los

In [18]:
torch.ones(4)

tensor([1., 1., 1., 1.])