In [None]:
# Mount Google Drive at /content/drive. The next cell reads the dataset through the
# relative path ./drive/MyDrive/..., which assumes the working directory is /content
# (Colab's default) — confirm if running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the pickled dataset from Drive into the working directory.
# -o overwrites existing files without asking: on a re-run an interactive
#    overwrite prompt may block the cell waiting for input.
# -q suppresses the per-file "inflating: ..." listing that would flood the output.
!unzip -o -q ./drive/MyDrive/vit_sr/pickled.zip -d .

# Drop four specific clips from the training split.
# NOTE(review): the reason is not recorded here — presumably these clips are too
# short or corrupt for the fixed 1024-frame spectrogram crop used in training; confirm.
# -f keeps the cell idempotent: re-running it does not fail if a file is already gone.
!rm -f ./UnzippedDataset/train/137.mus ./UnzippedDataset/train/899.mus ./UnzippedDataset/train/1194.mus ./UnzippedDataset/train/462.mus

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# TODO(review): pin the transformers version this notebook was developed against.
# The import cell requests `AdamW` from transformers, which was deprecated and has
# been removed in recent releases — confirm against the pinned version.
%pip install -q transformers

In [None]:
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTModel, AdamW

import os
import math
import glob
import pickle


import librosa
from scipy import signal

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 transposed convolutions with a GELU between them, plus an identity skip.

    With ``kernel_size=3``, ``padding=1`` and ConvTranspose2d's default stride of 1,
    each convolution keeps the spatial size unchanged:
    (H - 1) * 1 - 2 * 1 + (3 - 1) + 1 = H. The block output therefore has the
    same shape as its input, and the skip connection is a plain addition.

    Args:
        device: device the convolution weights are created on. An enclosing model
            may still move the block afterwards with ``.to(...)``.
        channels: number of input and output channels (default 1, the single-channel
            layout used in this notebook). Input and output channel counts are
            equal so that ``block(x) + x`` is well defined.
    """

    def __init__(self, device='cpu', channels=1):
        super().__init__()
        # The attribute name and layer order are kept stable so existing state_dict
        # checkpoints (keys block.0.* and block.2.*) still load.
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), padding=1),
        ).to(device)

    def forward(self, inputs):
        """Return ``block(inputs) + inputs``.

        ``inputs`` is an (N, channels, H, W) tensor; the output has the same shape.
        """
        return self.block(inputs) + inputs

In [None]:
class GenerativeNetwork(nn.Module):
    """ViT encoder followed by a convolutional decoder producing a 1024x1024 single-channel image.

    Shapes follow from the ViTConfig built in ``__init__``:
      * input: (B, 1, 1024, 1024) (``num_channels=1``, ``image_size=1024``);
      * ViT with 16x16 patches gives 64 * 64 = 4096 patch tokens of size ``hidden_size``
        (64), plus a CLS token that ``forward`` drops;
      * each 64-dim token is reshaped to an 8x8 tile, and the tiles are laid out on the
        64x64 patch grid, giving (B, 1, 512, 512);
      * a stride-2 transposed convolution upsamples to (B, 1, 1024, 1024);
      * four residual blocks, each followed by a GELU, refine the result.

    Args:
        device: device the model is created on. ``forward`` moves its input there.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.hidden_size = 64
        self.patch_size = 16
        configuration = ViTConfig(
            num_attention_heads=8,
            num_hidden_layers=8,
            hidden_size=self.hidden_size,
            patch_size=self.patch_size,
            num_channels=1,
            image_size=1024,
        )
        self.vit = ViTModel(configuration).to(self.device)
        # Layer order is unchanged so saved state_dicts (model.0 ... model.9) still load.
        self.model = nn.Sequential(
            # Bring the 512x512 tile mosaic back to 1024x1024: (512 - 1) * 2 - 2 * 1 + 4 = 1024.
            nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, padding=1, stride=2),
            nn.GELU(),
            # Skip-connection refinement blocks (each preserves the spatial size).
            ResidualBlock(),
            nn.GELU(),
            ResidualBlock(),
            nn.GELU(),
            ResidualBlock(),
            nn.GELU(),
            ResidualBlock(),
            nn.GELU(),
        ).to(device)

    def patch_to_img(self, x, patch_size):
        """Reassemble patch tokens into a single-channel image.

        Args:
            x: (B, N, patch_size ** 2) tokens with the CLS token already removed. Token
                ``n`` is placed at grid row ``n // sqrt(N)`` and column ``n % sqrt(N)``
                (row-major). This is assumed to match the ViT patch embedding order.
                Each token's values fill its tile row-major.
            patch_size: side length of the square tile each token is reshaped into.

        Returns:
            (B, 1, sqrt(N) * patch_size, sqrt(N) * patch_size) tensor.

        Raises:
            ValueError: if N is not a perfect square or the token size is not
                ``patch_size ** 2``.
        """
        batch, num_patches, hidden = x.shape
        grid = math.isqrt(num_patches)
        if grid * grid != num_patches:
            raise ValueError(f"number of patches ({num_patches}) is not a perfect square")
        if patch_size * patch_size != hidden:
            raise ValueError(f"token size {hidden} cannot be reshaped into a {patch_size}x{patch_size} tile")
        # (B, row, col, py, px) -> (B, row, py, col, px): each image row is then contiguous.
        x = x.reshape(batch, grid, grid, patch_size, patch_size).permute(0, 1, 3, 2, 4)
        side = grid * patch_size
        return x.reshape(batch, 1, side, side)

    def forward(self, inputs):
        """Run the ViT on ``inputs`` and decode its patch tokens into an image.

        Args:
            inputs: (B, 1, 1024, 1024) tensor on any device.

        Returns:
            (B, 1, 1024, 1024) tensor on ``self.device``.
        """
        # Move unconditionally: `.to` is a no-op when the tensor is already there.
        # A check like `inputs.device == 'cpu'` is always False, because torch.device
        # never compares equal to a str.
        inputs = inputs.to(self.device)
        tokens = self.vit(pixel_values=inputs).last_hidden_state[:, 1:, :]  # drop the CLS token
        tile = math.isqrt(tokens.shape[2])  # hidden_size 64 -> 8x8 tiles
        return self.model(self.patch_to_img(tokens, tile))

In [None]:
class LHB_Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a directory of pickled clips, one clip per file.

    Files are listed once and sorted. This keeps the index -> file mapping stable, so
    a seeded shuffle gives the same sample order on every machine; raw ``os.listdir``
    order is filesystem-dependent.

    Args:
        path: directory holding the pickled clips.
        ext: file extension to include, with or without a leading dot (e.g. ``'mus'``).
            Other files in the directory are ignored. An empty value includes every file.
        max_samples: each clip is truncated to its first ``max_samples`` elements.
            NOTE(review): the origin of the default 1318970 is not documented here — confirm.
    """

    def __init__(self, path, ext, max_samples=1318970):
        self.path = path
        self.ext = ext
        self.max_samples = max_samples
        suffix = '.' + ext.lstrip('.') if ext else ''
        self.items_in_dir = sorted(
            name for name in os.listdir(self.path)
            if name.endswith(suffix) and os.path.isfile(os.path.join(self.path, name))
        )
        self.len = len(self.items_in_dir)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Load file ``idx`` and return its first ``max_samples`` elements.

        The unpickled object only needs to support slicing. It is presumably a 1-D
        sample array (numpy array or tensor) — confirm against the data.
        """
        name = os.path.join(self.path, self.items_in_dir[idx])
        # SECURITY: pickle.load can execute arbitrary code; only use this on trusted data.
        with open(name, 'rb') as fd:
            song = pickle.load(fd)
        return song[:self.max_samples]

In [None]:
train_path = './UnzippedDataset/train'
train_ds = LHB_Dataset(path=train_path, ext='mus')

# Sanity check: shape of the first (truncated) clip, then the number of clips.
print(train_ds[0].shape)
print(len(train_ds))

In [None]:
# Training loader: a dedicated, seeded CPU RNG makes the shuffle order of indices reproducible.
train_generator = torch.Generator(device='cpu').manual_seed(13)
trainloader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    generator=train_generator,
)

In [None]:
# Seed torch's global RNG so the weight initialisation is reproducible, matching the
# seeded DataLoader shuffle.
torch.manual_seed(13)
generator = GenerativeNetwork(device)
# torch.optim.AdamW replaces transformers.AdamW, which is deprecated and removed in
# recent transformers releases. eps=1e-6 and weight_decay=0.0 are transformers.AdamW's
# defaults (torch's are 1e-8 and 1e-2), so the optimisation closely matches the old optimizer.
optimizer_gen = torch.optim.AdamW(generator.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)
loss_gen = nn.MSELoss()

In [None]:
import datetime
def save_model(model, path):
    """Save ``model.state_dict()`` to a timestamped file under ``path``.

    The file is named ``generator_<dd-mm-YYYY_HH-MM-SS>.pt``. Two saves within the same
    second write to the same name, so the second overwrites the first.

    Args:
        model: any ``nn.Module``.
        path: output directory; created (including parents) if missing.

    Returns:
        The path of the written file.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(path, exist_ok=True)
    stamp = datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
    filename = os.path.join(path, 'generator_' + stamp + '.pt')
    torch.save(model.state_dict(), filename)
    return filename

In [None]:
def _spectrogram_bands(audio, n_fft=4096, n_bins=1024, n_frames=1024):
    """Split one clip's min-max normalized dB spectrogram into a low and a high band.

    Args:
        audio: 1-D sequence of samples (anything ``np.asarray`` accepts).
        n_fft: STFT size; also the length of the symmetric Hamming window.
        n_bins: number of frequency rows per band.
        n_frames: number of leading STFT frames kept.

    Returns:
        ``(low, high)``: float32 tensors of shape (1, n_bins, n_frames). ``low`` is
        spectrogram rows 1..n_bins (row 0 is skipped). ``high`` is the next n_bins rows.

    Raises:
        ValueError: if the spectrogram has fewer than ``1 + 2 * n_bins`` rows or
            fewer than ``n_frames`` frames (e.g. the clip is too short).
    """
    stft = librosa.stft(np.asarray(audio), n_fft=n_fft, win_length=n_fft, window=signal.windows.hamming(n_fft))
    spec = torch.as_tensor(librosa.amplitude_to_db(np.abs(stft)), dtype=torch.float32)
    if spec.shape[0] < 1 + 2 * n_bins or spec.shape[1] < n_frames:
        raise ValueError(
            f"spectrogram of shape {tuple(spec.shape)} is too small for two "
            f"{n_bins}-bin bands of {n_frames} frames"
        )
    lo, hi = spec.min(), spec.max()
    # A constant (e.g. silent) clip would otherwise divide by zero and yield NaNs.
    spec = (spec - lo) / (hi - lo).clamp_min(1e-12)
    low = spec[1:1 + n_bins, :n_frames].unsqueeze(0)
    high = spec[1 + n_bins:1 + 2 * n_bins, :n_frames].unsqueeze(0)
    return low, high


def train(generator, epochs, train_loader, optimizer=None, criterion=None, save_every=40, save_dir="models"):
    """Train ``generator`` to predict each clip's high spectrogram band from its low band.

    For every batch of clips, each clip is turned into (low, high) bands by
    ``_spectrogram_bands``. The generator's output for the stacked low bands is
    compared with the stacked high bands.

    Args:
        generator: model mapping (B, 1, 1024, 1024) low-band images to same-shaped predictions.
        epochs: number of passes over ``train_loader``.
        train_loader: iterable of batches of 1-D audio clips.
        optimizer: optimizer over ``generator``'s parameters; defaults to the notebook's ``optimizer_gen``.
        criterion: loss function; defaults to the notebook's ``loss_gen``.
        save_every: write a checkpoint every ``save_every`` epochs; one is always written
            after the final epoch so the trained model is never lost.
        save_dir: directory passed to ``save_model``.

    Returns:
        List with the mean training loss of each epoch (``nan`` for an epoch without batches).
    """
    if optimizer is None:
        optimizer = optimizer_gen
    if criterion is None:
        criterion = loss_gen
    # Inputs go wherever the model's weights live.
    model_device = next(generator.parameters()).device

    generator.train()
    epoch_means = []
    for epoch in range(1, epochs + 1):
        print(f"Start epoch {epoch}")
        history = []
        for data_batch in train_loader:
            bands = [_spectrogram_bands(clip.squeeze(dim=0)) for clip in data_batch]
            batch_lb = torch.stack([low for low, _ in bands]).to(model_device)
            batch_hb = torch.stack([high for _, high in bands]).to(model_device)

            optimizer.zero_grad()
            gen_hb = generator(batch_lb)
            loss = criterion(gen_hb, batch_hb)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            print(f'Loss: {step_loss}')
            history.append(step_loss)

        mean_loss = sum(history) / len(history) if history else float('nan')
        epoch_means.append(mean_loss)
        if epoch % save_every == 0 or epoch == epochs:
            save_model(generator, save_dir)

        fig, ax = plt.subplots()
        ax.plot(history, label="loss")
        ax.set(title=f"Epoch {epoch} training loss", xlabel="step", ylabel="loss")
        ax.legend()
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per epoch

        print(f"Mean loss: {mean_loss}")
    return epoch_means

In [None]:
# Run the training loop for 100 epochs over the shuffled training clips.
train(generator, epochs=100, train_loader=trainloader);