<a href="https://colab.research.google.com/github/smhall97/hallucinating_GANs/blob/main/Combined_CNN_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
# @title Imports

import os
import glob
import torch
import time
import copy

import numpy as np
import matplotlib.pyplot as plt
import pickle
from tqdm.notebook import tqdm

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.optim import lr_scheduler

In [15]:
# @title Set device (GPU or CPU)
# NMA code
# inform the user if the notebook uses GPU or CPU.

def set_device():
  """Choose the torch device string for this notebook.

  Returns "cuda" when torch reports a usable GPU and "cpu" otherwise.
  On CPU a hint about switching the Colab runtime to GPU is printed;
  on GPU a short confirmation line is printed instead.
  """
  has_gpu = torch.cuda.is_available()
  if not has_gpu:
    print("WARNING: For this notebook to perform best, "
        "if possible, in the menu under `Runtime` -> "
        "`Change runtime type.`  select `GPU` ")
    return "cpu"

  print("GPU is enabled in this notebook.")
  return "cuda"

device = set_device()



In [16]:
# @title Helper Functions and Loaders

def scale_minmax(X):
    """Linearly rescale X so that its minimum maps to 0 and its maximum to 1.

    X may be anything supporting elementwise arithmetic together with
    ``.min()`` / ``.max()`` (the loaders in this notebook pass numpy arrays
    or slices of them).

    A constant X (max == min) would give 0 / 0 = NaN everywhere, which
    silently poisons training. In that case the range is treated as 1, so
    the result is all zeros with the same shape and float dtype as the
    normal path.
    """
    x_min = X.min()
    value_range = X.max() - x_min
    if value_range == 0:
        # (X - x_min) is already all zeros; dividing by 1 keeps it so.
        value_range = 1
    return (X - x_min) / value_range


def pickle_loader_mel(file):
  """Load a pickled mel spectrogram and turn it into a 3-channel array.

  parameters:
  file: path to a ``.pkl`` file holding the spectrogram array.

  returns:
  The array transposed with axes [1, 2, 0] (the leading axis of the stored
  array becomes the last one), min-max scaled to [0, 1] with scale_minmax,
  and concatenated with itself three times along that last axis.
  NOTE(review): this yields 3 channels only if the stored array has a single
  leading channel axis (presumably (1, n_mels, frames)); confirm against the
  spectrogram-generation code.

  Security: pickle.load can execute arbitrary code; only load trusted files.
  """
  with open(file, 'rb') as f:
      data = pickle.load(f)

  # Pure array work from here on; the file handle is already closed.
  data = np.transpose(data, axes=[1, 2, 0])
  data = scale_minmax(data)

  # Replicate the channel three times to match the 3-channel Normalize/VGG input.
  return np.concatenate((data, data, data), axis=2)


def pickle_loader_stft_real(file):
  """Load a pickled STFT and return only its real part as a 3-channel array.

  parameters:
  file: path to a ``.pkl`` file.

  returns:
  The stored object squeezed on axis 0. Channel 0 of its last axis is taken,
  min-max scaled to [0, 1] and repeated three times along the last axis,
  giving shape (H, W, 3).
  NOTE(review): channel 0 is assumed to be the real part and channel 1 the
  imaginary part (per the "real and imaginary parts" comment); confirm
  against the spectrogram-generation code.

  Security: pickle.load can execute arbitrary code; only load trusted files.
  """
  with open(file, 'rb') as f:
      data = pickle.load(f)

  data = np.squeeze(data, axis=0)

  # Slice with 0:1 to keep a trailing length-1 axis. numpy arrays have no
  # `.unsqueeze`; the slice works for numpy arrays and torch tensors alike
  # and gives the same (H, W, 1) shape. Min/max are taken over exactly the
  # same elements as scaling data[:, :, 0]. The imaginary channel is not
  # used by this loader, so it is not scaled.
  real = scale_minmax(data[:, :, 0:1])

  return np.concatenate((real, real, real), axis=2)


def pickle_loader_stft(file):
  """Load a pickled STFT as a 3-channel array: [channel 0, channel 1, zeros].

  parameters:
  file: path to a ``.pkl`` file.

  The stored object is squeezed on axis 0. Channels 0 and 1 of the last axis
  (described in this notebook as the real and imaginary parts) are each
  min-max scaled to [0, 1] in place. An all-zero third channel is then
  appended so the result matches a 3-channel network input.

  Security: pickle.load can execute arbitrary code; only load trusted files.
  """
  with open(file, 'rb') as handle:
      spec = pickle.load(handle)
      spec = np.squeeze(spec, axis=0)

      # real and imaginary parts are scaled independently
      for channel in (0, 1):
          spec[:, :, channel] = scale_minmax(spec[:, :, channel])

      height, width = spec.shape[0], spec.shape[1]
      padding = np.zeros((height, width, 1))
      spec = np.concatenate((spec, padding), axis=2)

  return spec


def make_sets(classes, items_per_class, ratios, dataset=None):
  """
  Split a class-ordered dataset into stratified train / val / test subsets.

  The dataset is assumed to hold exactly `items_per_class` consecutive
  samples per class, class after class (NOTE(review): presumably true for the
  torchvision DatasetFolder built in the main loop; verify if the data
  changes). Inside each class block, the first samples go to train, the next
  to validation and the following ones to test.

  parameters:
  classes: number of classes in dataset
  items_per_class: elements per class (assumes that the dataset is balanced across classes)
  ratios: list or array with ratios for each subset [ratio_training, ratio_validation, ratio_test]
  dataset: the dataset to split. Defaults to the notebook-global `dataset`,
           which is what this function always used before the parameter existed.

  returns:
  dict mapping 'train', 'val' and 'test' to torch.utils.data.Subset objects.

  raises:
  ValueError if the subset sizes add up to more than items_per_class. That
  would make one class's indices spill into the next class's block.
  """
  if dataset is None:
    if 'dataset' not in globals():
      raise NameError("name 'dataset' is not defined; pass it via make_sets(..., dataset=...)")
    dataset = globals()['dataset']

  # Whole-sample sizes. Float sizes (e.g. 0.1 * 55 = 5.5) made np.arange produce
  # fractional starts that truncate into overlapping val/test indices. The tiny
  # epsilon absorbs float error such as 0.29 * 100 == 28.999999999999996.
  train_size = int(ratios[0] * items_per_class + 1e-9)
  val_size = int(ratios[1] * items_per_class + 1e-9)
  test_size = int(ratios[2] * items_per_class + 1e-9)

  if train_size + val_size + test_size > items_per_class:
    raise ValueError(
        'ratios {} need {} samples per class but only {} are available'.format(
            list(ratios), train_size + val_size + test_size, items_per_class))

  train_ix, val_ix, test_ix = [], [], []
  for i in range(classes):
    class_ix = items_per_class * i
    train_end = class_ix + train_size
    val_end = train_end + val_size

    train_ix.extend(range(class_ix, train_end))
    val_ix.extend(range(train_end, val_end))
    test_ix.extend(range(val_end, val_end + test_size))

  subsets = {
          'train': torch.utils.data.Subset(dataset, np.asarray(train_ix, dtype=int)),
          'val': torch.utils.data.Subset(dataset, np.asarray(val_ix, dtype=int)),
          'test': torch.utils.data.Subset(dataset, np.asarray(test_ix, dtype=int))
          }

  return subsets


def get_cfg_transform(t, augment):
  """Return the loading configuration for one spectrogram type.

  parameters:
  t: transform name, one of 'mel', 'stft' or 'stft_r' (real part of the STFT only).
  augment: whether data augmentation is requested. NOTE: both settings
           currently build the same pipeline, so this flag has no effect yet.

  returns:
  params: string built from the notebook-global n_fft and hop_length (plus
          n_mels for 'mel'). The main loop uses it as the spectrogram
          sub-directory name.
  pickle_loader: function that reads one stored spectrogram file.
  data_transforms: torchvision pipeline ToTensor -> CenterCrop(dims) ->
          Normalize(mean=.5, std=.5 for each of 3 channels).

  raises:
  ValueError for an unknown transform name.
  """
  # A single if/elif chain so every name matches exactly one branch.
  # NOTE(review): 513 presumably equals n_fft // 2 + 1 for n_fft = 1024, and
  # 2580 the frame count of the stored clips; both are hard-coded here.
  if t == 'stft_r':
    params = '{}_{}'.format(str(n_fft), str(hop_length))
    pickle_loader = pickle_loader_stft_real
    dims = (513, 2580)
  elif t == 'stft':
    params = '{}_{}'.format(str(n_fft), str(hop_length))
    pickle_loader = pickle_loader_stft
    dims = (513, 2580)
  elif t == 'mel':
    params = '{}_{}_{}'.format(str(n_fft), str(hop_length), n_mels)
    pickle_loader = pickle_loader_mel
    dims = (128, 2580)
  else:
    # Raise instead of printing: continuing would hit unbound locals below.
    raise ValueError(
        "Unknown transform {!r}; expected 'mel', 'stft' or 'stft_r'".format(t))

  # Set the pytorch transforms for the dataloaders (not the same as the
  # mathematical transforms from before).
  # TODO: implement real augmentation for augment=True; both cases are identical today.
  data_transforms = transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.CenterCrop(dims),
                                    transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5])
                                    ])

  return params, pickle_loader, data_transforms



In [17]:
# @title Train model function from PyTorch

# Original code from this tutorial: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
def _run_phase(model, criterion, optimizer, loader, phase, device):
    """Run one pass of `phase` ('train' or 'val') over `loader`.

    Weights are updated only when phase == 'train'.

    returns:
    (mean loss per example, accuracy as a 0-dim double tensor on the
    predictions' device)

    raises:
    ValueError if `loader` yields no examples (the averages would divide by zero).
    """
    training = phase == 'train'
    if training:
        model.train()  # Set model to training mode
    else:
        model.eval()   # Set model to evaluate mode

    running_loss = 0.0
    running_corrects = 0
    num_examples = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_size = inputs.size(0)
        num_examples += batch_size

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward; track history only in train
        with torch.set_grad_enabled(training):
            outputs = model(inputs.float())
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # backward + optimize only if in training phase
            if training:
                loss.backward()
                optimizer.step()

        # statistics; assumes the criterion averages over the batch (the default
        # 'mean' reduction), hence the re-weighting by batch size
        running_loss += loss.item() * batch_size
        running_corrects += torch.sum(preds == labels)

    if num_examples == 0:
        raise ValueError(
            "dataloader for phase '{}' yielded no examples".format(phase))

    print('number of examples in loader = ', num_examples)
    print(f'RUNNING LOSS: {running_loss}, RUNNING CORRECTS: {running_corrects}')
    print()

    epoch_loss = running_loss / num_examples
    epoch_acc = running_corrects.double() / num_examples
    return epoch_loss, epoch_acc


# Original code from this tutorial: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                dataloaders=None, device=None):
    """Train `model` and keep the weights with the best validation accuracy.

    Each epoch runs a 'train' phase (forward, backward, optimizer step,
    then one scheduler step) followed by a 'val' phase (forward only).
    After the last epoch, the weights from the epoch with the highest
    validation accuracy are loaded back into `model`.

    parameters:
    model, criterion, optimizer, scheduler: torch objects; `model` is
        expected to already live on `device`.
    num_epochs: number of passes over the training data.
    dataloaders: dict whose 'train' and 'val' entries yield (inputs, labels)
        batches. Defaults to the notebook-global `dataloaders`.
    device: where inputs/labels are moved. Defaults to the notebook-global `device`.

    returns:
    (model, train_acc_list, val_acc_list). The per-epoch accuracies are 0-dim
    double tensors moved to the CPU, so the pickled lists can be loaded and
    plotted on machines without a GPU.

    raises:
    ValueError if a phase's dataloader yields no examples.
    """
    if dataloaders is None:
        dataloaders = globals()['dataloaders']
    if device is None:
        device = globals()['device']

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_acc_list, val_acc_list = [], []

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, criterion, optimizer, dataloaders[phase], phase, device)

            if phase == 'train':
                # The LR schedule advances once per epoch, after the optimizer steps.
                scheduler.step()
                train_acc_list.append(epoch_acc.cpu())
            else:
                val_acc_list.append(epoch_acc.cpu())

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_acc_list, val_acc_list

In [18]:
# @title Parameters

# General parameters
mount_drive = True   # whether to mount Google Drive in the next cell
model_name = 'vgg16'
# Transforms to be trained on, in order; each entry is 'mel', 'stft' or 'stft_r'.
transforms_list = ['mel', 'stft', 'stft_r']
augment = False      # whether data augmentation should be performed
path = '/content/drive/MyDrive/HallucinatingGANs/Code/data/'
outpath = path + 'models/'  # trained weights and accuracy logs are written here

# Audio-transform parameters
n_fft, n_mels = 1024, 128
hop_length = 256  # smaller hop size leads to better reconstruction but takes longer to compute
power = 2.0       # squared power spectrogram
samplerate = 22050

# Model-training parameters
epochs, n_workers, minibatch_size = 40, 4, 4


In [19]:
# @title Mount Google Drive

if mount_drive:
  # Both the import and the mount live under the flag, so setting
  # mount_drive = False really skips Drive (the mount call previously sat
  # outside the `if` and raised NameError when the flag was off).
  from google.colab import drive
  drive.mount('/content/drive')  # it will ask you for a verification code

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# @title Main Loop

for transform in transforms_list:
  # Train one VGG16-based classifier per spectrogram type, then save its
  # per-epoch accuracies (.pkl) and best weights (.pt) under `outpath`.

  # get settings for the transform to be performed
  params, pickle_loader, data_transforms = get_cfg_transform(transform, augment)

  # path of the corresponding spectrograms
  data_dir = os.path.join(os.path.abspath(path), 'spectrograms', transform, params)
  print(data_dir)

  # names of output files: <transform>[AUG]_<params>_<model_name>.{pkl,pt}
  # (model_name fills the last slot so runs with different backbones don't collide)
  aug_tag = 'AUG' if augment else ''
  label = '{}{}_{}_{}'.format(transform, aug_tag, params, model_name)

  # Create the output folder up front so the saves after training cannot
  # fail on a missing directory.
  os.makedirs(outpath, exist_ok=True)
  acc_file = os.path.join(outpath, label + '.pkl')
  model_file = os.path.join(outpath, label + '.pt')

  # load dataset
  dataset = torchvision.datasets.DatasetFolder(root=data_dir,
                                               transform=data_transforms,
                                               loader=pickle_loader,
                                               extensions='.pkl',
                                               )
  # Use the classes DatasetFolder actually found (these define the label
  # indices) rather than os.listdir, which would also count stray files.
  genres = dataset.classes
  n_classes = len(genres)

  # generate training, validation and test sets
  # make_sets assumes a balanced dataset with this many samples per class
  items_per_class = 100
  if len(dataset) != n_classes * items_per_class:
    print('WARNING: found {} samples but make_sets expects {} classes x {} items; '
          'its per-class index blocks will not line up with the real classes'
          .format(len(dataset), n_classes, items_per_class))
  subsets = make_sets(classes=n_classes,
                      items_per_class=items_per_class,
                      ratios=[.8, .1, .1])

  # create dataloaders (the 'test' subset is held out and not used here)
  dataloaders = {x: torch.utils.data.DataLoader(subsets[x], batch_size=minibatch_size,
                                                shuffle=True, num_workers=n_workers)
                 for x in ['train', 'val']}

  # Load pretrained VGG
  # code extracted from:
  # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor
  # https://pytorch.org/vision/stable/models.html
  vgg16 = models.vgg16(pretrained=True)

  # Fine-tune the whole network: every pretrained parameter stays trainable.
  # Set requires_grad = False here instead to freeze the backbone.
  for param in vgg16.parameters():
      param.requires_grad = True

  # Replace the last classifier layer with a small head for n_classes genres.
  # Its input width is read from the layer being replaced (the RHS is
  # evaluated before the assignment). New modules have requires_grad=True by default.
  vgg16.classifier[6] = nn.Sequential(
                        nn.Linear(vgg16.classifier[6].in_features, 256),
                        nn.ReLU(),
                        nn.Linear(256, n_classes),
                        nn.LogSoftmax(dim=1))

  # The head already outputs log-probabilities, so pair it with NLLLoss.
  # CrossEntropyLoss would apply log-softmax a second time (numerically
  # equivalent, but redundant).
  criterion = nn.NLLLoss()

  # All parameters (pretrained backbone and new head) are optimized.
  optimizer_conv = optim.SGD(vgg16.parameters(), lr=0.001, momentum=0.9)

  # Decay LR by a factor of 0.1 every 7 epochs
  exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

  # model trains - change `epochs` in the parameters cell to increase training time
  vgg16 = vgg16.float()
  model_ft, train_acc_list, val_acc_list = train_model(vgg16.to(device), criterion, optimizer_conv, exp_lr_scheduler,
                                                       num_epochs=epochs)

  # save accuracies from training procedure
  acc_dict = {'train_acc': train_acc_list, 'val_acc': val_acc_list}
  print('writing file: ' + acc_file)
  with open(acc_file, 'wb') as f:
      pickle.dump(acc_dict, f, pickle.HIGHEST_PROTOCOL)

  # save model weights (best validation epoch)
  print('writing file: ' + model_file)
  torch.save(model_ft.state_dict(), model_file)


/content/drive/MyDrive/HallucinatingGANs/Code/data/spectrograms/mel/1024_256_128


  cpuset_checked))
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 0/39
----------


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
