In [None]:
# Mount Google Drive into the Colab VM; the dataset zip is read from it in the next cell
# via a relative ./drive path (assumes the working directory is /content, Colab's default).
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
!unzip ./drive/MyDrive/vit_sr/pickled.zip -d .

In [None]:
!pip install transformers

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
import os
import math
import numpy as np

from scipy import signal
from scipy.fft import fft, fftshift

import torch
import torch.nn as nn
from torch.nn import functional as F

import cv2
import pickle
from PIL import Image
import matplotlib.pyplot as plt

import IPython

In [None]:
from transformers import ViTModel, ViTConfig, AdamW

import librosa
import librosa.display as display

In [None]:
# Use the current CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
class ResidualBlock(nn.Module):
    """Single-channel upsampling residual block.

    Both paths start with a stride-2 3x3 transposed convolution (padding 1), so an
    (H, W) input becomes (2H-1, 2W-1):
      * ``block``:     ConvT(stride 2) -> GELU -> ConvT(stride 1, size-preserving)
      * ``ext_block``: ConvT(stride 2) shortcut
    The block returns the element-wise sum of the two paths.

    Args:
        device: device both paths are moved to at construction time.
    """

    def __init__(self, device='cpu'):
        super().__init__()

        def upconv(stride):
            # 1 -> 1 channel, 3x3 kernel, padding 1; stride 2 doubles size minus one.
            return nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(3, 3), padding=1, stride=stride)

        # Main path (created first so parameter init order is unchanged).
        self.block = nn.Sequential(upconv(2), nn.GELU(), upconv(1)).to(device)
        # Shortcut path matching the main path's output size.
        self.ext_block = nn.Sequential(upconv(2)).to(device)

    def forward(self, inputs):
        shortcut = self.ext_block(inputs)
        return self.block(inputs) + shortcut

In [None]:
class GenerativeNetwork(nn.Module):
    """ViT encoder followed by a transposed-convolution upsampler.

    The ViT is configured for single-channel 1024x1024 inputs with 16x16 patches and
    hidden_size=4, so it yields 64x64 patch tokens of 4 values each. Each token is
    reshaped into a 2x2 tile and the tiles are laid back out row-major, giving a
    (B, 1, 128, 128) image. ``self.model`` upsamples that to (B, 1, 1024, 1024).

    Args:
        device: device the submodules are moved to, and that ``forward`` moves inputs to.
    """

    def __init__(self, device='cpu'):
        super(GenerativeNetwork, self).__init__()
        self.device = device
        self.hidden_size = 4
        self.patch_size = 16
        configuration = ViTConfig(num_attention_heads=4, num_hidden_layers=8, hidden_size=self.hidden_size, patch_size=self.patch_size, num_channels=1, image_size=1024)
        self.vit = ViTModel(configuration).to(self.device)
        # Spatial sizes in the comments assume the 128x128 tiled ViT output.
        self.model = nn.Sequential(
                        nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(3,3), padding=0, stride=2),  # 128 -> 257
                        nn.GELU(),
                        nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(3,3), padding=1, stride=1),  # 257 -> 257
                        nn.GELU(),

                        ResidualBlock(device),  # 257 -> 513
                        nn.GELU(),
                        ResidualBlock(device),  # 513 -> 1025
                        nn.GELU(),

                        nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(2,2), padding=1, stride=1),  # 1025 -> 1024
                        # NOTE(review): GELU is bounded below (about -0.17), while the dB-scaled
                        # spectrogram targets used in training can be negative - confirm intended.
                        nn.GELU()
        ).to(device)

    def patch_to_img(self, x, patch_size):
        """Tile a sequence of flattened square patches back into a 1-channel image.

        Args:
            x: (B, NumPatches, patch_size**2) tensor. NumPatches must be a perfect square;
               patches are taken row-major over the grid, pixels row-major within a patch.
            patch_size: side length of each square patch.

        Returns:
            (B, 1, grid * patch_size, grid * patch_size) tensor, grid = sqrt(NumPatches).

        Raises:
            ValueError: if NumPatches is not a perfect square or the last dim != patch_size**2.
        """
        batch, num_patches, hidden = x.shape
        grid = int(math.sqrt(num_patches))
        if grid * grid != num_patches or patch_size * patch_size != hidden:
            raise ValueError(
                f'cannot tile {num_patches} patches of size {hidden} into square '
                f'{patch_size}x{patch_size} patches on a square grid'
            )
        x = x.reshape(batch, grid, grid, patch_size, patch_size)  # (B, grid_row, grid_col, py, px)
        x = x.permute(0, 1, 3, 2, 4)                              # (B, grid_row, py, grid_col, px)
        return x.reshape(batch, 1, grid * patch_size, grid * patch_size)

    def forward(self, inputs):
        """Map ``inputs`` (ViT ``pixel_values``) to the upsampled prediction."""
        # Unconditional move (a no-op when already there). The previous check compared a
        # torch.device with the string 'cpu', which is not a reliable equality test.
        inputs = inputs.to(self.device)
        vit_res = self.vit(pixel_values=inputs)
        # Drop the first token (the ViT's CLS token); keep only the patch tokens.
        tokens = vit_res.last_hidden_state[:, 1:, :]
        patch_size_after_vit = int(math.sqrt(tokens.shape[2]))
        image = self.patch_to_img(tokens, patch_size_after_vit)
        return self.model(image)

In [None]:
class DiscriminativeNetwork(nn.Module):
    """CNN classifier returning, per input image, a sigmoid probability of being real.

    Six unpadded stride-2 3x3 convolutions shrink a 1024x1024 single-channel input as
    1024 -> 511 -> 255 -> 127 -> 63 -> 31 -> 15, so the flattened feature vector has
    64 * 15 * 15 = 14400 entries. The first Linear layer is sized for exactly that, so
    inputs must be (B, 1, 1024, 1024).

    Args:
        device: device the classifier is moved to, and that ``forward`` moves inputs to.
    """

    def __init__(self, device='cpu'):
        super(DiscriminativeNetwork, self).__init__()
        self.device = device
        self.classifier = nn.Sequential(
                                        nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=2),    # -> 511x511
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(2),
                                        nn.Conv2d(in_channels=2, out_channels=4, kernel_size=3, stride=2),    # -> 255x255
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(4),
                                        nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=2),    # -> 127x127
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(8),
                                        nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2),   # -> 63x63
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(16),
                                        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2),  # -> 31x31
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(32),
                                        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2),  # -> 15x15
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(64),
                                        nn.Flatten(),
                                        nn.Dropout(0.3),
                                        nn.Linear(in_features=14400, out_features=1024),  # 64 * 15 * 15
                                        nn.LeakyReLU(0.2),
                                        nn.Dropout(0.3),
                                        nn.Linear(in_features=1024, out_features=128),
                                        nn.LeakyReLU(0.2),
                                        nn.Dropout(0.3),
                                        nn.Linear(in_features=128, out_features=1),
                                        nn.Sigmoid()
        ).to(self.device)

    def forward(self, inputs):
        """Return a (B, 1) tensor of probabilities in [0, 1]."""
        # Unconditional move (a no-op when already there). The previous check compared a
        # torch.device with the string 'cpu', which is not a reliable equality test.
        return self.classifier(inputs.to(self.device))

In [None]:
class LHB_Dataset(torch.utils.data.Dataset):
    """Dataset with one pickled song per file in a directory.

    Item ``idx`` is the unpickled object from the idx-th file (in sorted filename order),
    sliced to its first ``max_len`` elements.

    Args:
        path: directory containing the pickled files. Every entry in it is treated as an item.
        ext: stored as ``self.ext`` but not used to filter files.
            NOTE(review): confirm whether files should be filtered by this extension.
        max_len: number of leading elements kept per item (default 1318970).
    """

    def __init__(self, path, ext, max_len=1318970):
        self.path = path
        self.ext = ext
        self.max_len = max_len
        # List once and sort, so the index -> file mapping does not depend on the
        # filesystem's listing order (keeps a seeded DataLoader shuffle reproducible).
        self.items_in_dir = sorted(os.listdir(self.path))
        self.len = len(self.items_in_dir)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        name = os.path.join(self.path, self.items_in_dir[idx])
        # SECURITY: pickle.load can execute arbitrary code - only point this at trusted files.
        with open(name, 'rb') as fd:
            song = pickle.load(fd)
        # NOTE(review): songs shorter than max_len stay shorter, which would make default
        # batch collation fail on mismatched lengths.
        return song[:self.max_len]

In [None]:
train_path = './UnzippedDataset/train'
train_ds = LHB_Dataset(train_path, 'mus')

# Sanity check: shape of one (truncated) song, then the number of songs found.
print(train_ds[0].shape, len(train_ds), sep='\n')

In [None]:
# Training DataLoader. A fixed-seed CPU generator makes the shuffle permutation reproducible.
train_generator = torch.Generator(device='cpu').manual_seed(13)
trainloader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=3,
    shuffle=True,
    generator=train_generator,
)

In [None]:
# Models (the constructors already move their submodules to `device`; .to() is a safeguard).
generator = GenerativeNetwork(device).to(device)
discriminator = DiscriminativeNetwork(device).to(device)

# Optimizers.
# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 matches the
# transformers implementation's default, keeping the update close to the original.
optimizer_gen = torch.optim.AdamW(generator.parameters(), lr=1e-4, weight_decay=1e-4, eps=1e-6)
# NOTE(review): lr=1e-7 is three orders of magnitude below the generator's lr - confirm intended.
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=1e-7)

# Losses: MSE for spectrogram content, BCE on the discriminator's sigmoid output.
loss_gen = nn.MSELoss()
loss_dis = nn.BCELoss()

In [None]:
import datetime
def save_model(model, path='./checkpoints'):
    """Save ``model.state_dict()`` to a timestamped ``generator_<dd-mm-YYYY_HH-MM-SS>.pt`` file.

    Args:
        model: module whose state_dict is saved.
        path: destination directory; created (including parents) if missing. Defaults to
            './checkpoints' so callers that omit it no longer fail with a TypeError.

    Returns:
        Full path of the written file. Two saves within the same second share a name,
        so the later one overwrites the earlier.
    """
    os.makedirs(path, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
    filename = os.path.join(path, 'generator_' + timestamp + '.pt')
    torch.save(model.state_dict(), filename)
    return filename

In [None]:
def _song_to_lf_hf(song, window, num_cols):
    """Split one waveform's dB spectrogram into (low, high) frequency parts.

    Args:
        song: 1-D waveform accepted by ``np.asarray``.
        window: STFT analysis window of length 4096.
        num_cols: number of leading STFT frames kept.

    Returns:
        (lf, hf) float32 tensors shaped (1, rows, num_cols). ``lf`` holds frequency rows
        1..1024 (row 0 is dropped) and ``hf`` holds rows 1025 onward.
    """
    stft = librosa.stft(np.asarray(song), n_fft=4096, win_length=4096, window=window)
    spectrogram = torch.tensor(librosa.amplitude_to_db(np.abs(stft))).float().unsqueeze(0)
    return spectrogram[:, 1:1025, :num_cols], spectrogram[:, 1025:, :num_cols]


def _order_of_magnitude(value):
    """floor(log10(value)) for positive ``value``; 0 otherwise (log10 is undefined there)."""
    return math.floor(math.log10(value)) if value > 0 else 0


def train(trainloader, generator, discriminator, optimizer_gen, optimizer_dis, loss_gen, loss_dis, epoches=1, beta=1.0, device='cpu',
          alpha=1.5, log_filename='reshapeAfterVit_V1.txt', save_path='./checkpoints'):
    """Adversarially train ``generator`` to predict high-frequency spectrogram rows from low ones.

    For every song in a batch, a dB-scaled STFT (n_fft=4096, Hamming window) is split by
    ``_song_to_lf_hf`` into generator input (rows 1..1024) and target (rows 1025..), both
    cropped to the first 1024 frames. Each batch performs:
      1. a discriminator step on real targets (label 1) and detached generator outputs (label 0);
      2. a generator step minimising ``alpha * loss_gen + beta * adversarial BCE``.

    After each epoch, ``beta`` is reset to 10 ** (order(mean content loss) - order(mean adversarial
    loss)), which puts ``beta * adversarial`` on the same order of magnitude as the content loss.
    The generator is then saved under ``save_path`` and a summary is appended to ``log_filename``.

    Args:
        trainloader: iterable of batches, each iterable over per-song waveforms.
        generator, discriminator: modules being trained.
        optimizer_gen, optimizer_dis: optimizers for the respective modules.
        loss_gen: content loss, called as loss_gen(prediction, target).
        loss_dis: BCE-style loss on the discriminator's probabilities.
        epoches: number of epochs.
        beta: initial weight of the adversarial term.
        device: device the batches are moved to.
        alpha: weight of the content term.
        log_filename: text file that per-epoch summaries are appended to.
        save_path: directory the generator checkpoint is written to each epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches in an epoch.
    """
    # Width (frames) of the generator input; must match the model's 1024x1024 input size.
    NUM_COLS = 1024
    # Loop-invariant STFT window, built once instead of once per song.
    window = signal.windows.hamming(4096)

    for epoch in range(epoches):
        print('EPOCH: ', epoch)
        # The losses below are per-batch means, so epoch averages divide by batch count.
        num_batches = 0
        total_dis_loss = 0.0      # discriminator loss during D training
        total_adv_loss = 0.0      # adversarial BCE during G training
        total_content_loss = 0.0  # loss_gen(prediction, target)
        total_loss = 0.0          # alpha * content + beta * adversarial

        for data_batch in trainloader:
            print('START BATCH PROCESS')
            pairs = [_song_to_lf_hf(data.squeeze(dim=0), window, NUM_COLS) for data in data_batch]
            batch_lf = torch.stack([lf for lf, _ in pairs]).to(device)
            batch_hf = torch.stack([hf for _, hf in pairs]).to(device)

            # One generator forward per batch: detached for the D step, attached for the G step.
            generated = generator(batch_lf)

            # --- Discriminator step: real targets labelled 1, generated ones labelled 0.
            combined_data = torch.cat((batch_hf, generated.detach()), dim=0)
            labels = torch.cat((torch.ones(batch_hf.shape[0]), torch.zeros(generated.shape[0])), dim=0)
            shuffled_indexes = np.random.permutation(combined_data.shape[0])
            combined_data = combined_data[shuffled_indexes]
            labels = labels[shuffled_indexes].to(device)

            # zero_grad also clears the D gradients accumulated by the previous G step's backward.
            optimizer_dis.zero_grad()
            discriminator_out = discriminator(combined_data).reshape(-1)
            discriminator_loss = loss_dis(discriminator_out, labels)
            print("Discriminator " + str(discriminator_loss.item()))
            discriminator_loss.backward()
            optimizer_dis.step()

            # --- Generator step: content loss plus fooling the (just updated) discriminator.
            optimizer_gen.zero_grad()
            generator_loss = loss_gen(generated, batch_hf)
            discriminator_out_gen = discriminator(generated).reshape(-1)
            discriminator_loss_gen = loss_dis(discriminator_out_gen, torch.ones_like(discriminator_out_gen))
            print("Generator content " + str(generator_loss.item()))
            print("Generator adv " + str(discriminator_loss_gen.item()))

            loss = alpha * generator_loss + beta * discriminator_loss_gen
            loss.backward()
            optimizer_gen.step()

            num_batches += 1
            total_dis_loss += discriminator_loss.item()
            total_adv_loss += discriminator_loss_gen.item()
            total_content_loss += generator_loss.item()
            total_loss += loss.item()

        if num_batches == 0:
            raise ValueError('trainloader yielded no batches')

        mean_dis_loss = total_dis_loss / num_batches
        mean_adv_loss = total_adv_loss / num_batches
        mean_gen_loss = total_content_loss / num_batches
        mean_loss = total_loss / num_batches

        # NOTE(review): the original code contained a no-op `if b_pow > 0: b_pow = b_pow`;
        # a clamp on the exponent may have been intended.
        beta = 10.0 ** (_order_of_magnitude(mean_gen_loss) - _order_of_magnitude(mean_adv_loss))

        save_model(generator, save_path)

        report = (
            'EPOCH ' + str(epoch + 1)
            + '\n\t -> Discriminative Loss during D Training = ' + str(mean_dis_loss) + ', during G Training = ' + str(mean_adv_loss)
            + '\n\t -> Generative Loss = ' + str(mean_loss) + ' ---> alpha * ' + str(mean_gen_loss) + ' beta * ' + str(mean_adv_loss)
        )
        # Trailing newline so consecutive epochs don't run together in the log file.
        with open(log_filename, 'a') as log_file:
            log_file.write(report + '\n')
        print(report)

In [None]:
# Run one epoch of adversarial training with the models, optimizers and losses defined above.
train(
    trainloader=trainloader,
    generator=generator,
    discriminator=discriminator,
    optimizer_gen=optimizer_gen,
    optimizer_dis=optimizer_dis,
    loss_gen=loss_gen,
    loss_dis=loss_dis,
    epoches=1,
    beta=1.0,
    device=device,
)