In [25]:
import os
# Per avere una traccia più precisa dell'errore
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import math
import numpy as np

from scipy import signal
from scipy.fft import fft, fftshift

import torch
import torch.nn as nn
from torch.nn import functional as F

import cv2
import pickle
from PIL import Image
import matplotlib.pyplot as plt

import IPython
from tqdm import tqdm

In [26]:
import librosa

In [27]:
from transformers import ViTModel

In [28]:
from transformers import ViTConfig

In [29]:
import librosa

In [30]:
!nvidia-smi

Thu Jul  6 15:50:35 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:1A:00.0 Off |                  N/A |
| 30%   24C    P8    22W / 250W |   5166MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:1B:00.0 Off |                  N/A |
| 30%   24C    P8     6W / 250W |      8MiB / 11019MiB |      0%      Default |
|       

In [31]:
CUDA_VISIBLE_DEVICES=6
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [32]:
class ResidualConvTranspose2d(nn.Module):

    def __init__(self, device='cpu'):
        super(ResidualConvTranspose2d, self).__init__()

        self.convTranspose1 = nn.ConvTranspose2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1).to(device)

        self.convTranspose2 = nn.ConvTranspose2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1).to(device)

    def forward(self, inputs):
        outputConvTranspose1 = self.convTranspose1(inputs)
        outputConvTranspose2 = self.convTranspose1(outputConvTranspose1)
        return inputs + outputConvTranspose2


In [33]:
class ResidualBlock(nn.Module):

    def __init__(self, device='cpu'):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
                        nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(3,3), padding=(2,1), stride=(2,1)), 
                        nn.GELU(), 
                        nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(3,3), padding=1, stride=1)
        ).to(device) 
        
        self.ext_block = nn.Sequential(
                        nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(3,3), padding=(2,1), stride=(2,1))
        ).to(device) 
        
    def forward(self, inputs):
        extended_input = self.ext_block(inputs)
        convolved_input = self.block(inputs)
        return convolved_input + extended_input

In [34]:
class GenerativeNetwork(nn.Module):
    """
    Input Shape: (b, 1, 640, 640)
    Output ViT: (b, num_pathces, hidden_size)
    After Reshape: (b, 1, x, x) dove x è una dimensione generica che puoi decidere
    Ouput Shape: (b, 1, 1024, 800)
    """
    
    def __init__(self, device='cpu'):
        super(GenerativeNetwork, self).__init__()
        self.device = device
        
        self.hidden_size = 64
        self.patch_size = 32
        configuration = ViTConfig(num_hidden_layers=8, 
                                  num_attention_heads=16, 
                                  hidden_size=self.hidden_size, 
                                  patch_size=self.patch_size, 
                                  num_channels=1, 
                                  image_size=640)
        
        self.vit = ViTModel(configuration).to(self.device)
        self.refine_model = nn.Sequential(
            
                        nn.ConvTranspose2d(in_channels=400, out_channels=1024, kernel_size=5, padding=2, stride=2, output_padding=1), 
                        nn.LeakyReLU(negative_slope=0.2), #
                        nn.ConvTranspose2d(in_channels=1024, out_channels=1024, kernel_size=5, padding=2, stride=2, output_padding=1), 
                        nn.LeakyReLU(negative_slope=0.2), #

                        ResidualConvTranspose2d(device),
                        nn.LeakyReLU(negative_slope=0.2),
                        ResidualConvTranspose2d(device),
                        nn.LeakyReLU(negative_slope=0.2),
                        ResidualConvTranspose2d(device),
                        
                        nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),
                        nn.Conv2d(in_channels=1024, out_channels=640, kernel_size=1)
        ).to(device)
        
        
    def forward(self, inputs):
        if inputs.device == 'cpu':
            inputs = inputs.to(self.device)
        batch_size = inputs.shape[0]
        vit_res = self.vit(pixel_values=inputs)
        inputs = vit_res.last_hidden_state[:, 1:, :]
        inputs = inputs.view(batch_size, -1, 8, 8)
        result = self.refine_model(inputs)
        result = result.view(batch_size, 1, 1024, 640)
        return result

In [35]:

g = GenerativeNetwork(device)
res = g(torch.zeros((1,1,640,640)).to(device))
print(res.shape)


torch.Size([1, 1, 1024, 640])


In [36]:
class DiscriminativeNetwork(nn.Module): 
    """
    Input Shape: (b, 1, 1025, 800)
    Ouput Shape: (b, 1)
    """
    
    def __init__(self, device='cpu'):
        super(DiscriminativeNetwork, self).__init__()
        self.device = device
        self.classifier = nn.Sequential(
                                        nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=2),
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(2),
                                        nn.Conv2d(in_channels=2, out_channels=4, kernel_size=3, stride=2),
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(4),
                                        nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=2),
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(8),
                                        nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2),
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(16),
                                        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2),
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(32),
                                        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2), 
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(64),
                                        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2), 
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(128),
                                        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2), 
                                        nn.LeakyReLU(0.2),
                                        nn.BatchNorm2d(256),
                                        nn.Flatten(),
                                        nn.Dropout(0.3),
                                        nn.Linear(in_features=768, out_features=128),
                                        nn.LeakyReLU(0.2),
                                        nn.Dropout(0.3),
                                        nn.Linear(in_features=128, out_features=1),
                                        nn.Sigmoid()
                                        
                                        
        ).to(self.device)


    def forward(self, inputs):
        if inputs.device == 'cpu':
            inputs = inputs.to(self.device)
        return self.classifier(inputs)

In [37]:
class LHB_Dataset(torch.utils.data.Dataset):

    def __init__(self, path, ext):
        self.path = path
        self.ext = ext
        self.len = len(os.listdir(self.path))
        self.items_in_dir = os.listdir(self.path)


    def __len__(self):
        return self.len

    
    def __getitem__(self, idx):
       
        name = self.path + '/' + self.items_in_dir[idx] #self.path + '/' + str(idx) + "." + self.ext

        with open(name, 'rb') as fd:
            song = pickle.load(fd)

        return song #[:1321967]

In [38]:
test_path = "/home/simona/Adele/UnzippedDataset/test"

test_ds = LHB_Dataset(test_path, 'mus')

print(test_ds[0].shape)
print(len(test_ds))

(1321967,)
571


In [39]:
#test
test_generator = torch.Generator(device='cpu')
test_generator.manual_seed(13)
testloader = torch.utils.data.DataLoader(
                                            dataset=test_ds, 
                                            batch_size=1, 
                                            shuffle=False,
                                            generator=test_generator
                                        )

In [40]:
# Models
generator = GenerativeNetwork(device).to(device)

In [41]:
checkpoint = torch.load('./GEN_BestVal_ismisArch_smallerInput_test2805_v7000_1024out', map_location='cpu')
checkpoint

{'epoch': 84,
 'model_state_dict': OrderedDict([('vit.embeddings.cls_token',
               tensor([[[-0.0126, -0.0016, -0.0090, -0.0107, -0.0352, -0.0482,  0.0216,
                         -0.0522,  0.0596, -0.0389, -0.0194, -0.0337, -0.0408,  0.0173,
                         -0.0504, -0.0656,  0.0123, -0.0193, -0.0366,  0.0462, -0.0941,
                          0.0396, -0.0170, -0.0212,  0.0059, -0.0191, -0.0527, -0.0377,
                          0.0083, -0.0291, -0.0444,  0.0421,  0.0087, -0.0129,  0.0028,
                          0.0180, -0.0313,  0.0090, -0.0659,  0.0632,  0.0296, -0.0336,
                          0.0090,  0.0116,  0.0541,  0.0317,  0.0892,  0.0662,  0.0408,
                          0.0446, -0.0501,  0.0204, -0.0084,  0.0053,  0.0656,  0.0106,
                          0.0111, -0.0018,  0.0320,  0.0707,  0.0266,  0.0417,  0.0496,
                          0.0023]]])),
              ('vit.embeddings.position_embeddings',
               tensor([[[-0.0126, -0.00

In [42]:
generator.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [43]:
def compute_signal_to_noise(truth, reco):
    den = np.sqrt(np.sum((truth - reco)**2)) + 1e-6
    return 10.*np.log10(1e-6 + np.sqrt(np.sum(truth**2))/den)
    
def compute_signal_to_noise_pytorch(truth, reco):
    truth = truth.view(-1,truth.shape[2], truth.shape[3])
    reco = reco.view(-1,reco.shape[2], reco.shape[3])
    den = torch.sqrt(torch.sum(torch.pow((truth - reco), 2)))
    return torch.mean(10.0 * torch.log10(1e-6 + torch.sqrt(torch.sum(torch.pow(truth, 2)))/den))
    
def compute_lsd(truth, reco):
    true_X = np.log10(np.abs(truth)**2 + 1e-6)
    reco_X = np.log10(np.abs(reco)**2 + 1e-6)
    reco_X_diff_squared = (true_X - reco_X)**2
    return np.mean(np.sqrt(np.mean(reco_X_diff_squared, axis=0)))

def compute_lsd_pytorch(truth, reco):
    truth = truth.view(-1,truth.shape[2], truth.shape[3])
    reco = reco.view(-1,reco.shape[2], reco.shape[3])
    true_X = torch.log10(torch.pow(torch.abs(truth),2) + 1e-6)
    reco_X = torch.log10(torch.pow(torch.abs(reco),2) + 1e-6)
    diff = true_X - reco_X
    reco_X_diff_squared = torch.pow(diff, 2)
    return torch.mean(torch.mean(torch.sqrt(torch.mean(reco_X_diff_squared, dim=1)), dim=1))  
    
def get_metric_comparison(testloader, metric, generator=None, device='cpu'): 
    
    generator.eval()
    
    NUM_COLS = 640
    TOT_ROWS = 1664
    HF_ROWS = 1024
    LF_ROWS = 640
    
    
    total_value = 0.0
    count = 0

    for test_data in testloader:
        
        data = np.asarray(test_data).squeeze(axis=0)
        # Compute spectrograms
        stft = librosa.stft(np.asarray(data), n_fft=4096, win_length=4096, window=signal.windows.hamming(4096))
        spectrogram = librosa.amplitude_to_db(abs(stft))

        rows = spectrogram.shape[0]
        real_cols = spectrogram.shape[1]

        if real_cols % NUM_COLS > 0:
            cols_to_add = NUM_COLS - real_cols % NUM_COLS
            new_data = np.zeros(shape = (rows, real_cols + cols_to_add))
            new_data[:, : real_cols] = spectrogram
            new_data[:, real_cols : ] = spectrogram[:, -cols_to_add:]
            spectrogram = new_data
            cols = real_cols + cols_to_add
        else:
            cols = real_cols

        PTS = cols // NUM_COLS
        
        temp_data = np.zeros(shape=(PTS, HF_ROWS+LF_ROWS, NUM_COLS))
        for i in range(PTS):
            temp_data[i, :, :] = spectrogram[ : HF_ROWS+LF_ROWS, i*NUM_COLS : (i+1)*NUM_COLS]

        temp_data = torch.from_numpy(temp_data).view(PTS, 1, -1, NUM_COLS).float()
        ds_lf = temp_data[:, :, : LF_ROWS, :]
        ds_hf = temp_data[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :] 

        ds_lf = ds_lf.to(device)

        generated_hf = np.asarray(generator(ds_lf).detach().cpu())
        ds_hf = np.asarray(ds_hf.detach().cpu())
        ds_lf = np.asarray(ds_lf.detach().cpu())

        tmp_real = np.zeros(shape=(PTS, 1, HF_ROWS+LF_ROWS, NUM_COLS))
        tmp_real[:, :, : LF_ROWS, :] = ds_lf
        tmp_real[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :] = ds_hf
        
        tmp_pred = np.zeros(shape=(PTS, 1, HF_ROWS+LF_ROWS, NUM_COLS))
        tmp_pred[:, :, : LF_ROWS, :] = ds_lf
        tmp_pred[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :] = generated_hf

        real_spec = np.zeros(shape=(TOT_ROWS, NUM_COLS*PTS))
        pred_spec = np.zeros(shape=(TOT_ROWS, NUM_COLS*PTS))

        for j in range(PTS):
            real_spec[ : TOT_ROWS, j*NUM_COLS : (j+1)*NUM_COLS] = tmp_real[j, :, :, :]

            pred_spec[ : TOT_ROWS, j*NUM_COLS : (j+1)*NUM_COLS] = tmp_pred[j, :, :, :]

        real_spec = real_spec[:, :real_cols]
        pred_spec = pred_spec[:, :real_cols]
           
        
        total_value = total_value + metric(real_spec, pred_spec)
        count = count+1
        
    return total_value / count


def get_metric_comparison_onlyLB(testloader, metric, device='cpu'): 
    NUM_COLS = 640
    TOT_ROWS = 1664
    HF_ROWS = 1024
    LF_ROWS = 640
    
    total_value = 0.0
    count = 0

    for test_data in testloader:
        
        data = np.asarray(test_data).squeeze(axis=0)
        # Compute spectrograms
        stft = librosa.stft(np.asarray(data), n_fft=4096, win_length=4096, window=signal.windows.hamming(4096))
        spectrogram = librosa.amplitude_to_db(abs(stft))

        rows = spectrogram.shape[0]
        real_cols = spectrogram.shape[1]

        if real_cols % NUM_COLS > 0:
            cols_to_add = NUM_COLS - real_cols % NUM_COLS
            new_data = np.zeros(shape = (rows, real_cols + cols_to_add))
            new_data[:, : real_cols] = spectrogram
            new_data[:, real_cols : ] = spectrogram[:, -cols_to_add:]
            spectrogram = new_data
            cols = real_cols + cols_to_add
        else:
            cols = real_cols

        PTS = cols // NUM_COLS
        
        temp_data = np.zeros(shape=(PTS, HF_ROWS+LF_ROWS, NUM_COLS))
        for i in range(PTS):
            temp_data[i, :, :] = spectrogram[ : HF_ROWS+LF_ROWS, i*NUM_COLS : (i+1)*NUM_COLS]
            
        temp_data = torch.from_numpy(temp_data).view(PTS, 1, -1, NUM_COLS).float()

        ds_lf = temp_data[:, :, : LF_ROWS , :] 
        ds_hf = temp_data[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :]

        min_value = torch.min(ds_hf)
        generated_hf = torch.ones_like(ds_hf) * min_value
        ds_hf = np.asarray(ds_hf.detach().cpu())
        ds_lf = np.asarray(ds_lf.detach().cpu())

        tmp_real = np.zeros(shape=(PTS, 1, HF_ROWS+LF_ROWS, NUM_COLS))
        tmp_real[:, :, : LF_ROWS, :] = ds_lf
        tmp_real[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :] = ds_hf
        
        tmp_pred = np.zeros(shape=(PTS, 1, HF_ROWS+LF_ROWS, NUM_COLS))
        tmp_pred[:, :, : LF_ROWS, :] = ds_lf
        tmp_pred[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :] = generated_hf

        real_spec = np.zeros(shape=(TOT_ROWS, NUM_COLS*PTS))
        pred_spec = np.zeros(shape=(TOT_ROWS, NUM_COLS*PTS))

        for j in range(PTS):
            real_spec[ : TOT_ROWS , j*NUM_COLS : (j+1)*NUM_COLS] = tmp_real[j, :, :, :]
            pred_spec[ : TOT_ROWS , j*NUM_COLS : (j+1)*NUM_COLS] = tmp_pred[j, :, :, :]

        real_spec = real_spec[:, :real_cols]
        pred_spec = pred_spec[:, :real_cols]
        
        total_value = total_value + metric(real_spec, pred_spec)
        count = count+1
        
    return total_value / count

            
def get_metric_comparison_on_interpolation(testloader, metric): 
    NUM_COLS = 640
    TOT_ROWS = 1664
    HF_ROWS = 1024
    LF_ROWS = 640
    
    total_value = 0.0
    count = 0

    for test_data in testloader:
        
        data = np.asarray(test_data).squeeze(axis=0)
        # Compute spectrograms
        stft = librosa.stft(np.asarray(data), n_fft=4096, win_length=4096, window=signal.windows.hamming(4096))
        spectrogram = librosa.amplitude_to_db(abs(stft))

        rows = spectrogram.shape[0]
        real_cols = spectrogram.shape[1]

        if real_cols % NUM_COLS > 0:
            cols_to_add = NUM_COLS - real_cols % NUM_COLS 
            new_data = np.zeros(shape = (rows, real_cols + cols_to_add))
            new_data[:, : real_cols] = spectrogram
            new_data[:, real_cols : ] = spectrogram[:, -cols_to_add:]
            spectrogram = new_data
            cols = real_cols + cols_to_add
        else:
            cols = real_cols

        PTS = cols // NUM_COLS
        
        temp_data = np.zeros(shape=(PTS, HF_ROWS+LF_ROWS, NUM_COLS))
        for i in range(PTS):
            temp_data[i, :, :] = spectrogram[ : HF_ROWS+LF_ROWS, i*NUM_COLS : (i+1)*NUM_COLS]

        temp_data = temp_data.reshape(PTS, 1, LF_ROWS+HF_ROWS, NUM_COLS)
        ds_lf = temp_data[:, :, : LF_ROWS, :]
        ds_hf = temp_data[:, :, LF_ROWS : LF_ROWS+HF_ROWS, :] 
        
        dim = (NUM_COLS, HF_ROWS+LF_ROWS)
        tmp_hf = ds_hf.copy().reshape(ds_hf.shape[0], ds_hf.shape[2], ds_hf.shape[3])
        tmp_pred = np.zeros(shape=(PTS, HF_ROWS+LF_ROWS, NUM_COLS))
        for i in range(PTS):
            tmp_pred[i] = cv2.resize(tmp_hf[i], dim, interpolation=cv2.INTER_CUBIC)

        tmp_real = np.zeros(shape=(PTS, 1, HF_ROWS+LF_ROWS, NUM_COLS))
        tmp_real[:, :, :LF_ROWS, :] = ds_lf
        tmp_real[:, :, LF_ROWS:LF_ROWS+HF_ROWS, :] = ds_hf

        real_spec = np.zeros(shape=(TOT_ROWS, NUM_COLS*PTS))
        pred_spec = np.zeros(shape=(TOT_ROWS, NUM_COLS*PTS))

        for j in range(PTS):
            real_spec[ : LF_ROWS+HF_ROWS, j*NUM_COLS : (j+1)*NUM_COLS] = tmp_real[j, :, :, :]
            pred_spec[ : LF_ROWS+HF_ROWS, j*NUM_COLS : (j+1)*NUM_COLS] = tmp_pred[j, :, :]

        real_spec = real_spec[:, :real_cols]
        pred_spec = pred_spec[:, :real_cols]
        
        total_value = total_value + metric(real_spec, pred_spec)
        count = count+1
        
    return total_value / count


In [44]:
#TEST GENERATOR
print('LSD: ', get_metric_comparison(testloader=testloader, metric=compute_lsd, generator=generator, device=device))
print('SNR: ', get_metric_comparison(testloader=testloader, metric=compute_signal_to_noise, generator=generator, device=device))

LSD:  0.5047658298916666
SNR:  5.261366655041625


In [45]:
#TEST ONLY LB
print('LSD: ', get_metric_comparison_onlyLB(testloader=testloader, metric=compute_lsd, device=device))
print('SNR: ', get_metric_comparison_onlyLB(testloader=testloader, metric=compute_signal_to_noise, device=device))

LSD:  0.6423481186869183
SNR:  3.1647817379036964


In [46]:
#TEST INTERPOLATION
print('LSD: ', get_metric_comparison_on_interpolation(testloader=testloader, metric=compute_lsd))
print('SNR: ', get_metric_comparison_on_interpolation(testloader=testloader, metric=compute_signal_to_noise))

LSD:  0.9560718635992412
SNR:  1.5952223733848003


In [47]:
gain = (0.8825366152647948 - 0.51646231465691) / 0.8825366152647948
print(gain)

0.41479786138736974


In [48]:
gain = (3.968512690039224 - 2.2772004729735045) / 3.968512690039224
print(gain)

0.4261828924752671
