In [1]:
# To hold the code of a blank transformer model
import sys
sys.path.append('/workspace/fourth_year_project/MusicGen')
#print(sys.path)

In [2]:
from MyAudioDataset import MyAudioDataset
from AudioCodesDataset import AudioCodesDataset

In [3]:
from audiocraft.models import CompressionModel
from audiocraft.models.encodec import InterleaveStereoCompressionModel
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
#model = model.cuda()
comp_model = InterleaveStereoCompressionModel(model).cuda()
# move to GPU
##comp_model = comp_model.cuda()



In [4]:
mydataset = MyAudioDataset('/workspace/small_model_data3', 'recording_01_')

In [5]:
audio_codes_dataset = AudioCodesDataset(comp_model=comp_model, dataset=mydataset)

In [6]:
audio_codes_dataset.run_compression()

In [8]:
print(audio_codes_dataset.data_map[0]['original'].shape)
print(audio_codes_dataset.data_map[0]['original_norm'].shape)

torch.Size([8, 1500])
torch.Size([8, 1500])


In [9]:
audio_codes_dataset.data_map[0]['original_norm']

tensor([[-1.8611, -1.8611, -1.8611,  ..., -1.0767, -1.0767, -1.0767],
        [-1.8611, -1.8611, -1.8611,  ..., -1.0767, -1.0767, -1.0767],
        [ 0.3033,  1.0248,  1.0248,  ...,  0.3646,  0.3646, -2.0135],
        ...,
        [-1.0437,  0.4118,  0.9556,  ...,  1.1112,  0.5155,  0.9321],
        [-0.5454,  1.1678,  0.9446,  ...,  0.8676,  0.9446,  0.8676],
        [-0.5454,  1.1678,  0.9446,  ...,  0.8676,  0.9446,  0.8676]],
       device='cuda:0')

In [10]:
len(audio_codes_dataset.data_map)

570

In [11]:
audio_codes_dataset.save_data('90_degree_compress_tensors.pkl')

In [10]:
len(audio_codes_dataset.data_map)

570

In [5]:
audio_codes_dataset = AudioCodesDataset(comp_model)

In [6]:
audio_codes_dataset.load_data('90_degree_compress_tensors.pkl')

In [7]:
audio_codes_dataset.data_map[0]['original_norm'].shape

torch.Size([8, 1500])

In [8]:
import torch
from torch import nn
import torchaudio
import torch.nn.functional as F

In [9]:

class AudioTransformer(nn.Module):
    def __init__(self, comp_model, d_model, nhead, num_layers, dim_feedforward, compute_seperate_loss=False):
        super(AudioTransformer, self).__init__()
        self.input_encoding = nn.Linear(1500, d_model)  # input audio
        self.input_bn = nn.BatchNorm1d(d_model)

        self.transformer = nn.Transformer(d_model, nhead, num_layers, dim_feedforward)
        self.hidden_layer = nn.
        self.output_bn = nn.BatchNorm1d(d_model)
        self.output_decoding = nn.Linear(d_model, 1500)  # Decoding back to stereo audio
        #self.angle_encoding = nn.Linear(6, d_model)  # add this once 90 works. 
        
        if comp_model is not None:
            self.comp_model = comp_model.cuda()
        else:
            self.comp_model = None
        self.compute_seperate_loss = compute_seperate_loss

    # Orig and target are the normalized values of the codes
    def forward(self, orig, target, angle):
        orig = orig.cuda()
        target = target.cuda()
        #print(audio.shape)
        orig = self.input_encoding(orig)
        target = self.input_encoding(target)
        # Relu to rid negatives
        #orig = F.relu(orig)
        #target = F.relu(target)

        #angle = self.angle_encoding(angle)  # Process one-hot encoded angle
        #angle = angle.unsqueeze(1).repeat(1, audio.size(2), 1)  # Repeat angle for each time step
        #x = audio + angle  # Combine audio and angle

        x = self.transformer(src=orig, tgt=target)
        #x = F.relu(x)

        x = self.output_decoding(x)
        x = F.relu(x)

        # Scale back to integers
        #x = x * 1000
        #x = torch.round(x)

        return x

    def compress(self, stereo):
            if self.comp_model is None:
                raise Exception("No compression model found")
            stereo = stereo.cuda()
            with torch.no_grad():
                stereo, scale = self.comp_model.encode(stereo)
            return stereo


    def decompress(self, stereo):
            if self.comp_model is None:
                raise Exception("No compression model found")
            stereo = stereo.cuda()
            with torch.no_grad():
                stereo = self.comp_model.decode(stereo)
            return stereo
    
    def compute_mean_std(self):
        all_data = torch.cat([item['target'] for item in self.data_map] + [item['original'] for item in self.data_map])
        mean = torch.mean(all_data)
        std = torch.std(all_data)
        return mean, std

    
    
    def train_loop(self, dataset, batch_size=1, epochs=1, lr=0.001, cosine_loss=False):
        if not cosine_loss:
            loss_fn = nn.MSELoss()
        else:
            loss_fn = nn.CosineEmbeddingLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        

        for epoch in range(epochs):
            #for i, (target, target_norm, orig, orig_norm, angle, sr) in enumerate(train_loader):
            for i, (_, target, _, orig, angle, sr) in enumerate(train_loader):
                optimizer.zero_grad()
                target = target.cuda()
                orig = orig.cuda()
                # Convert wav to codes
                target_codes = target
                orig_codes = orig
                #print(target_codes.shape, orig_codes.shape)
                #target_codes = self.compress(target)
                #orig_codes = self.compress(orig)

                #print(type(target_codes[0][0]))
                #print(type(orig_codes[0][0]))
                # Pass codes to model
                output = self(orig=orig_codes.float(), target=target_codes.float(), angle=angle)
                #output = output.squeeze(0)
                #target = target_codes.squeeze(0)
                if cosine_loss:
                    if self.compute_seperate_loss:
                        total_loss = 0
                        for j in range(target_codes.shape[1]):
                            output_j = output[:, j, :]
                            target_codes_j = target_codes[:, j, :]
                            # Should be a 1D tensor of length equal to the batch size
                            # So if 4 then should be 4 ones
                            y = torch.ones(target_codes_j.shape[0]).cuda()
                            #print(output_j.shape, target_codes_j.shape, y.shape)
                            loss = loss_fn(output_j, target_codes_j, y)
                            total_loss += loss
                        total_loss /=  target_codes.shape[1]  # average the loss over the batch
                        total_loss.backward()  # backpropagate the average loss
                        optimizer.step()
                        scheduler.step()
                        print(f"Epoch {epoch}, batch {i}, loss {total_loss}")
                    else:
                        output = output.view(-1, 1500)
                        target_codes = target_codes.view(-1, 1500)
                        y = torch.ones((output.size(0),)).cuda()
                        loss = loss_fn(output, target_codes, y)
                        loss.backward()
                        optimizer.step()
                        scheduler.step()
                        print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")
                else:
                    # Using MSE loss
                    output = output.float()
                    target_codes = target_codes.float()
                    loss = loss_fn(output, target_codes)

                    #print(output, target_codes)
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    if i % 10 == 0:
                        print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")
                    


                #y = torch.ones_like(target_codes).cuda()
                #y = torch.ones(target_codes.size(0)).cuda()
                
                
                #print(output.shape, target_codes.shape, y.shape)
                #loss = loss_fn(output, target_codes, y)
                #loss.backward()
                #optimizer.step()
                #print(f"Epoch {epoch}, batch {i}, loss {loss.item()}")

            if epoch % 5 == 0 and epoch != 0:
                torch.save(self.state_dict(), f"model_{epoch}.pth")
                print(f"Saved model_{epoch}.pth")

        print("Finished Training")

In [10]:
myTransformer = AudioTransformer(comp_model=comp_model, d_model=512, nhead=4, num_layers=3, dim_feedforward=256).cuda()
myTransformer.train()



AudioTransformer(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-255): 256 x TransformerDecoderLayer(
          (self_attn): MultiheadAt

In [11]:
myTransformer.train_loop(dataset=audio_codes_dataset, batch_size=8, epochs=100, lr=0.001)

Epoch 0, batch 0, loss 1.1546118259429932
Epoch 0, batch 10, loss 0.9905588030815125
Epoch 0, batch 20, loss 0.9900474548339844
Epoch 0, batch 30, loss 0.9915304183959961
Epoch 0, batch 40, loss 0.9881868958473206
Epoch 0, batch 50, loss 0.9917421936988831
Epoch 0, batch 60, loss 0.9907203912734985
Epoch 0, batch 70, loss 0.9899943470954895
Epoch 1, batch 0, loss 0.9878147840499878
Epoch 1, batch 10, loss 0.9897182583808899
Epoch 1, batch 20, loss 0.9901285171508789
Epoch 1, batch 30, loss 0.9897047281265259
Epoch 1, batch 40, loss 0.9892682433128357
Epoch 1, batch 50, loss 0.9896485805511475
Epoch 1, batch 60, loss 0.9895039200782776
Epoch 1, batch 70, loss 0.9866693615913391
Epoch 2, batch 0, loss 0.9900076389312744
Epoch 2, batch 10, loss 0.9885878562927246
Epoch 2, batch 20, loss 0.989489734172821
Epoch 2, batch 30, loss 0.9887732863426208
Epoch 2, batch 40, loss 0.9908292293548584
Epoch 2, batch 50, loss 0.9901090264320374
Epoch 2, batch 60, loss 0.9891948103904724
Epoch 2, batch 

KeyboardInterrupt: 