In [1]:
# Mount Google Drive inside the Colab runtime so the project files are reachable.
from google.colab import drive
drive.mount('/content/gdrive')
# IPython magic: make the project folder on Drive the notebook's working directory.
%cd /content/gdrive/My Drive/DLProjects/JTM

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
/content/gdrive/My Drive/DLProjects/JTM


In [2]:
# Shell escape: list the GPUs visible to this runtime.
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-ebc587a8-e232-4538-79aa-4cf87e25d12b)


In [3]:
!pip install comet_ml



In [14]:
import os
from getpass import getpass

from comet_ml import Experiment

# Never store the API key in the notebook: read it from the environment and
# only prompt interactively when the variable is not set.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be revoked.
COMET_API_KEY = os.environ.get("COMET_API_KEY") or getpass("Comet API key: ")

experiment = Experiment(
    api_key=COMET_API_KEY,
    project_name="jtm",
    workspace="tiagocuervo"
)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/tiagocuervo/jtm/48eea7c4b34f41059803188280ac26b7



# Dataloader

In [5]:
from pathlib import Path
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler, BatchSampler
import librosa
import tqdm
import random
import time
import pickle
import re
import numpy as np


class AudioBatchData(Dataset):
    """Sliding-window dataset over raw audio stored on disk as pickled chunks.

    The tracks listed in the metadata CSV are concatenated into chunk files of
    roughly ``CHUNK_SIZE`` bytes. Chunks are grouped into packs of
    ``NUM_CHUNKS_INMEM`` chunks; a single pack is held in ``self.data`` at a
    time and ``_loadNextPack`` swaps in the following one.
    """

    def __init__(self,
                 rawAudioPath,
                 metadataPath,
                 sizeWindow,
                 labelsBy='composer',
                 outputPath=None,
                 CHUNK_SIZE=1e9,
                 NUM_CHUNKS_INMEM=2,
                 useGPU=False):
        """
        Args:
            - rawAudioPath (string): path to the raw audio files
            - metadataPath (string): path to the data set metadata (used to define labels)
            - sizeWindow (int): size of the sliding window
            - labelsBy (string): name of column in metadata according to which create labels
            - outputPath (string): path to the directory where chunks are to be created or are stored
            - CHUNK_SIZE (int): desired size in bytes of a chunk
            - NUM_CHUNKS_INMEM (int): number of chunks grouped in a pack, i.e. loaded in memory at a time
            - useGPU (bool): if True the loaded pack is moved to the GPU
        """
        self.NUM_CHUNKS_INMEM = NUM_CHUNKS_INMEM
        self.CHUNK_SIZE = CHUNK_SIZE
        self.rawAudioPath = Path(rawAudioPath)
        self.sizeWindow = sizeWindow
        self.useGPU = useGPU

        # Sort the tracks by label so that each category is contiguous, then
        # replace the labels by their integer category codes.
        self.sequencesData = pd.read_csv(metadataPath, index_col='id')
        self.sequencesData = self.sequencesData.sort_values(by=labelsBy)
        self.sequencesData[labelsBy] = self.sequencesData[labelsBy].astype('category')
        self.sequencesData[labelsBy] = self.sequencesData[labelsBy].cat.codes

        # NOTE(review): 'length' is used below as an offset into the
        # concatenated audio, so it presumably is the number of samples of
        # each track at the loading sample rate -- confirm against the CSV.
        self.totSize = self.sequencesData['length'].sum()

        self.category = labelsBy

        if outputPath is None:
            self.chunksDir = self.rawAudioPath / labelsBy
        else:
            self.chunksDir = Path(outputPath) / labelsBy

        if not os.path.exists(self.chunksDir):
            os.makedirs(self.chunksDir)

        packages2Load = self._listChunks()
        if len(packages2Load) == 0:
            self._createChunks()
            packages2Load = self._listChunks()
        else:
            print("Chunks already exist at", self.chunksDir)

        # Group the chunk files into packs of NUM_CHUNKS_INMEM files
        # (the last pack may be smaller).
        self.packs = [packages2Load[start:start + self.NUM_CHUNKS_INMEM]
                      for start in range(0, len(packages2Load), self.NUM_CHUNKS_INMEM)]

        self.currentPack = -1
        self.nextPack = 0
        self.sequenceIdx = 0

        self.data = None

        # The first call only initialises the pack counters, the second one
        # actually loads pack 0 in memory.
        self._loadNextPack(first=True)
        self._loadNextPack()

    def _listChunks(self):
        """Return the names of the chunk files present in ``self.chunksDir``."""
        return [fileName for fileName in os.listdir(self.chunksDir) if
                re.match(r'chunk_.*[0-9]+.pickle', fileName)]

    def _saveChunk(self, packageIdx, pack, packIds):
        """Pickle one chunk (concatenated waveforms) and the ids of its tracks."""
        print(f"Saved pack {packageIdx}")
        with open(self.chunksDir / f'chunk_{packageIdx}.pickle', 'wb') as handle:
            pickle.dump(torch.cat(pack, dim=0), handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(self.chunksDir / f'ids_{packageIdx}.pickle', 'wb') as handle:
            pickle.dump(packIds, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def _createChunks(self):
        """Load every track as 16 kHz audio and write it out in chunks of about CHUNK_SIZE bytes."""
        print("Creating chunks at", self.chunksDir)
        pack = []
        packIds = []
        packageSize = 0
        packageIdx = 0
        for trackId in tqdm.tqdm(self.sequencesData.index):
            sequence, samplingRate = librosa.load(self.rawAudioPath / (str(trackId) + '.wav'), sr=16000)
            sequence = torch.tensor(sequence).float()
            packIds.append(trackId)
            pack.append(sequence)
            packageSize += len(sequence) * 4  # float32 samples: 4 bytes each
            if packageSize >= self.CHUNK_SIZE:
                self._saveChunk(packageIdx, pack, packIds)
                pack = []
                packIds = []
                packageSize = 0
                packageIdx += 1
        # Flush the remaining tracks. Nothing is left when the last track
        # exactly triggered a save (or when there is no track at all), and
        # torch.cat would fail on an empty list.
        if pack:
            self._saveChunk(packageIdx, pack, packIds)

    def _loadNextPack(self, first=False):
        """Drop the pack currently in memory and load the next one.

        With ``first=True`` nothing is loaded, only the pack counters are
        updated.
        """
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            startTime = time.time()
            print('Loading files')
            # Offsets (in samples, within the pack) at which a new category /
            # a new sequence starts.
            self.categoryLabel = [0]
            packageIdx = [0]
            self.seqLabel = [0]
            packageSize = 0
            previousCategory = 0
            for packagePath in self.packs[self.currentPack]:
                # 'chunk_<k>.pickle' -> 'ids_<k>.pickle'
                with open(self.chunksDir / ('ids_' + packagePath.split('_', maxsplit=1)[-1]), 'rb') as handle:
                    # NOTE(review): pickle.load executes arbitrary code; only
                    # load chunk files produced by this class.
                    chunkIds = pickle.load(handle)
                for seqId in chunkIds:
                    seqInfo = self.sequencesData.loc[seqId]
                    currentCategory = seqInfo[self.category]
                    if currentCategory != previousCategory:
                        self.categoryLabel.append(packageSize)
                    previousCategory = currentCategory
                    packageSize += seqInfo.length
                    self.seqLabel.append(packageSize)
                packageIdx.append(packageSize)

            self.data = torch.empty(size=(packageSize,))
            if self.useGPU:
                # Tensor.cuda() returns a copy: the result must be kept,
                # otherwise the data silently stays on the CPU.
                self.data = self.data.cuda()

            for i, packagePath in enumerate(self.packs[self.currentPack]):
                with open(self.chunksDir / packagePath, 'rb') as handle:
                    self.data[packageIdx[i]:packageIdx[i + 1]] = pickle.load(handle)
            print(f'Loaded {len(self.seqLabel) - 1} sequences, elapsed={time.time() - startTime:.3f} secs')

        self.nextPack = (self.currentPack + 1) % len(self.packs)
        if self.nextPack == 0 and len(self.packs) > 1:
            self.currentPack = -1
            self.nextPack = 0
            self.sequenceIdx = 0

    def clear(self):
        """Delete the data and label offsets of the pack currently in memory."""
        if 'data' in self.__dict__:
            del self.data
        if 'categoryLabel' in self.__dict__:
            del self.categoryLabel
        if 'seqLabel' in self.__dict__:
            del self.seqLabel

    def getCategoryLabel(self, idx):
        """Return the index of the category interval of the current pack containing sample ``idx``.

        ``self.categoryLabel`` holds the start offset of every category and
        has no closing boundary, so an index located in the last category
        matches no larger offset: it is mapped to the last interval instead
        of letting ``next`` raise ``StopIteration``.
        """
        firstAfter = next((pos for pos, start in enumerate(self.categoryLabel) if start > idx),
                          len(self.categoryLabel))
        return firstAfter - 1

    def getSequenceLabel(self, idx):
        # NOTE(review): this indexes categoryLabel (not seqLabel) directly
        # with idx -- confirm that this is the intended behaviour.
        return self.categoryLabel[idx]

    def __len__(self):
        """Number of non-overlapping windows in the whole data set (all packs)."""
        return self.totSize // self.sizeWindow

    def __getitem__(self, idx):
        """Return the window of ``sizeWindow`` samples starting at ``idx`` (shape (1, sizeWindow)) and its label."""
        # Debugging aid: report suspicious start indices.
        if idx < 0 or idx >= len(self.data) - self.sizeWindow - 1:
            print(idx)

        outData = self.data[idx:(self.sizeWindow + idx)].view(1, -1)
        label = torch.tensor(self.getCategoryLabel(idx), dtype=torch.long)
        return outData, label

    def getBaseSampler(self, samplingType, batchSize, offset):
        """Build the batch sampler matching ``samplingType`` for the pack currently in memory."""
        if samplingType == "samecategory":
            return SameTrackSampler(batchSize, self.categoryLabel, self.sizeWindow, offset)
        if samplingType == "samesequence":
            return SameTrackSampler(batchSize, self.seqLabel, self.sizeWindow, offset)
        if samplingType == "sequential":
            return SequentialSampler(len(self.data), self.sizeWindow, offset, batchSize)

        sampler = UniformAudioSampler(len(self.data), self.sizeWindow, offset)
        return BatchSampler(sampler, batchSize, True)

    def getDataLoader(self, batchSize, samplingType, randomOffset, numWorkers=0,
                      onLoop=-1):
        r"""
        Get an AudioLoader iterating over the data set.
            - batchSize (int): batch size
            - samplingType (string):
                "samecategory": each batch is drawn from a single category
                "samesequence": each batch is drawn from a single sequence
                "sequential": sequential sampling
                else: uniform random sampling of the full audio vector
            - randomOffset (bool): if True add a random offset to the sampler
                                   at the beginning of each iteration
            - numWorkers (int): see torch.utils.data.DataLoader
            - onLoop (int): if >= 0, load this pack and iterate only over it
        """
        nLoops = len(self.packs)
        totSize = self.totSize // (self.sizeWindow * batchSize)
        if onLoop >= 0:
            self.currentPack = onLoop - 1
            self._loadNextPack()
            nLoops = 1

        def samplerCall():
            offset = random.randint(0, self.sizeWindow // 2) \
                if randomOffset else 0
            return self.getBaseSampler(samplingType, batchSize, offset)

        return AudioLoader(self, samplerCall, nLoops, self._loadNextPack, totSize, numWorkers)


class AudioLoader(object):
    r"""
    Iterable wrapper feeding an AudioBatchData object to torch DataLoaders.
    The dataset only keeps one big pack of audio in memory: this loader
    exhausts a DataLoader built on the current pack, asks the dataset to load
    the next pack, and repeats for the requested number of packs.
    """

    def __init__(self,
                 dataset,
                 samplerCall,
                 nLoop,
                 updateCall,
                 size,
                 numWorkers):
        r"""
        Args:
            - dataset (AudioBatchData): target dataset
            - samplerCall (function): called without argument, returns the
                                      batch sampler to use on the current pack
            - nLoop (int): number of packs to go through
            - updateCall (function): function loading the next pack
            - size (int): total number of batches (value returned by len())
            - numWorkers (int): see torch.utils.data.DataLoader
        """
        self.dataset = dataset
        self.samplerCall = samplerCall
        self.updateCall = updateCall
        self.nLoop = nLoop
        self.size = size
        self.numWorkers = numWorkers

    def __len__(self):
        return self.size

    def __iter__(self):
        for loopIdx in range(self.nLoop):
            # A fresh sampler is requested for every pack.
            packLoader = DataLoader(self.dataset,
                                    batch_sampler=self.samplerCall(),
                                    num_workers=self.numWorkers)
            yield from packLoader
            # No need to load anything once the last pack has been consumed.
            isLastLoop = loopIdx >= self.nLoop - 1
            if not isLastLoop:
                self.updateCall()


class UniformAudioSampler(Sampler):
    """Yields, in random order, the start index of every window of the data.

    Window ``i`` starts at ``offset + i * sizeWindow``. When a positive offset
    is used, one window less is sampled.
    """

    def __init__(self,
                 dataSize,
                 sizeWindow,
                 offset):
        self.sizeWindow = sizeWindow
        self.offset = offset
        nWindows = dataSize // sizeWindow
        self.len = nWindows - 1 if self.offset > 0 else nWindows

    def __len__(self):
        return self.len

    def __iter__(self):
        shuffledWindows = torch.randperm(self.len)
        starts = self.offset + self.sizeWindow * shuffledWindows
        return iter(starts.tolist())


class SequentialSampler(Sampler):
    """Batch sampler reading ``batchSize`` evenly spaced regions of the data in order.

    The data is cut in ``batchSize`` regions; batch number ``i`` holds the
    start index of the i-th window of every region.
    """

    def __init__(self, dataSize, sizeWindow, offset, batchSize):
        self.sizeWindow = sizeWindow
        self.offset = offset
        self.batchSize = batchSize

        nBatches = (dataSize // sizeWindow) // batchSize
        # One batch less when the windows are shifted by a positive offset
        self.len = nBatches - 1 if self.offset > 0 else nBatches

        regionSize = dataSize // batchSize
        self.startBatches = [region * regionSize for region in range(batchSize)]

    def __len__(self):
        return self.len

    def __iter__(self):
        for step in range(self.len):
            shift = self.offset + self.sizeWindow * step
            yield [shift + regionStart for regionStart in self.startBatches]


class SameTrackSampler(Sampler):
    """Batch sampler whose batches only hold windows of a single interval.

    ``samplingIntervals`` lists the boundaries (starting at zero) of
    consecutive intervals of the data, e.g. the sequences or the categories.
    Every batch gathers ``batchSize`` window start indices drawn without
    replacement from one interval. ``self.batches`` is an array of shape
    (number of batches, batchSize).
    """

    def __init__(self,
                 batchSize,
                 samplingIntervals,
                 sizeWindow,
                 offset):
        """
        Args:
            - batchSize (int): number of windows per batch
            - samplingIntervals (list): increasing boundaries of the intervals, first one must be 0
            - sizeWindow (int): size of a window
            - offset (int): shift applied to the start of every window

        Raises:
            AttributeError: if the first boundary is not zero
        """
        self.samplingIntervals = samplingIntervals
        self.sizeWindow = sizeWindow
        self.batchSize = batchSize
        self.offset = offset

        if self.samplingIntervals[0] != 0:
            raise AttributeError("Sampling intervals should start at zero")

        nWindows = len(self.samplingIntervals) - 1
        # How many windows each sequence/category lasts
        self.sizeSamplers = [(self.samplingIntervals[i + 1] -
                              self.samplingIntervals[i]) // self.sizeWindow
                             for i in range(nWindows)]

        # With an offset the last window of an interval would overflow it
        if self.offset > 0:
            self.sizeSamplers = [max(0, x - 1) for x in self.sizeSamplers]

        # (index of seq/cat, randomly permuted window numbers of that seq/cat)
        order = [(x, torch.randperm(val).tolist())
                 for x, val in enumerate(self.sizeSamplers) if
                 val > 0]

        # Build Batches
        batches = []
        for indexSampler, randperm in order:
            indexStart, sizeSampler = 0, self.sizeSamplers[indexSampler]
            while indexStart < (sizeSampler - self.batchSize):
                indexEnd = indexStart + self.batchSize
                locBatch = [self.getIndex(x, indexSampler)
                            for x in randperm[indexStart:indexEnd]]
                indexStart = indexEnd
                batches.append(locBatch)

        if batches:
            self.batches = np.vstack(batches)
        else:
            # np.vstack raises on an empty list: no interval was long enough
            # to fill a batch, the sampler is simply empty.
            self.batches = np.empty((0, self.batchSize), dtype=np.int64)

    def __len__(self):
        return len(self.batches)

    def getIndex(self, x, iInterval):
        """Start index in the data of window number ``x`` of interval ``iInterval``."""
        return self.offset + x * self.sizeWindow + self.samplingIntervals[iInterval]

    def __iter__(self):
        # random.shuffle must not be applied to a 2-D array: it swaps rows
        # through views, which duplicates some batches and loses others.
        # Shuffle the row indices instead and reorder the array with them.
        rowOrder = list(range(len(self.batches)))
        random.shuffle(rowOrder)
        self.batches = self.batches[np.asarray(rowOrder, dtype=np.intp)]
        return iter(self.batches)

# Model

In [6]:
import torch
import torch.nn as nn
import math


class ChannelNorm(nn.Module):
    """Normalizes its input along dimension 1 (mean / variance computed with
    ``x.mean(dim=1)`` and ``x.var(dim=1)``), with an optional learnt scale and
    shift of shape (1, numFeatures, 1).
    """

    def __init__(self,
                 numFeatures,
                 epsilon=1e-05,
                 affine=True):
        super().__init__()
        if not affine:
            self.weight = None
            self.bias = None
        else:
            paramShape = (1, numFeatures, 1)
            self.weight = nn.parameter.Parameter(torch.Tensor(*paramShape))
            self.bias = nn.parameter.Parameter(torch.Tensor(*paramShape))
        self.epsilon = epsilon
        self.p = 0
        self.affine = affine
        self.reset_parameters()

    def reset_parameters(self):
        """Set the scale to one and the shift to zero (no-op without affine parameters)."""
        if not self.affine:
            return
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        centered = x - x.mean(dim=1, keepdim=True)
        invStd = torch.rsqrt(x.var(dim=1, keepdim=True) + self.epsilon)
        normalized = centered * invStd

        if self.weight is None:
            return normalized
        return normalized * self.weight + self.bias


class SincConv1D(nn.Module):
    """1D convolution whose kernels are windowed band-pass (sinc based) filters.

    Only the lower cut-off frequency and the band width of every filter are
    learnt; the kernels are rebuilt from them at each forward pass.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False,
                 padding_mode='zeros', sampleRate=16000, minLowHz=50, minBandHz=50):
        super().__init__()
        # Unsupported configurations are rejected first
        if in_channels != 1:
            raise ValueError("SincConv1D only support one input channel (here, in_channels = {%i})" % (in_channels))
        if bias:
            raise ValueError('SincConv1D does not support bias.')
        if groups > 1:
            raise ValueError('SincConv1D does not support groups.')

        self.outChannels = out_channels
        # The kernel size is forced to be odd so that the filters are perfectly symmetric
        self.kernelSize = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.sampleRate = sampleRate
        self.minLowHz = minLowHz
        self.minBandHz = minBandHz

        # Cut-off frequencies equally spaced on the Mel scale
        firstHz = 30
        lastHz = self.sampleRate / 2 - (self.minLowHz + self.minBandHz)
        melEdges = np.linspace(self.hz2Mel(firstHz), self.hz2Mel(lastHz), self.outChannels + 1)
        hzEdges = self.mel2Hz(melEdges)
        # Lower frequency of every filter, shape (outChannels, 1)
        self.lowHz_ = torch.nn.Parameter(torch.Tensor(hzEdges[:-1]).view(-1, 1))
        # Band width of every filter, shape (outChannels, 1)
        self.bandHz_ = torch.nn.Parameter(torch.Tensor(np.diff(hzEdges)).view(-1, 1))

        # Hamming window: only the first half is computed
        halfSize = int((self.kernelSize / 2))
        nLin = torch.linspace(0, (self.kernelSize / 2) - 1, steps=halfSize)
        self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * nLin / self.kernelSize)
        # Thanks to the symmetry only the negative half of the time axis is needed
        n = (self.kernelSize - 1) / 2.0
        self.n_ = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sampleRate

    @staticmethod
    def hz2Mel(hz):
        """Convert a frequency in Hz to the Mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def mel2Hz(mel):
        """Convert a Mel value back to a frequency in Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def forward(self, waveforms):
        # The time axis and the window are plain tensors: follow the input device
        self.n_ = self.n_.to(waveforms.device)
        self.window_ = self.window_.to(waveforms.device)

        lowCut = self.minLowHz + torch.abs(self.lowHz_)
        highCut = torch.clamp(lowCut + self.minBandHz + torch.abs(self.bandHz_), self.minLowHz, self.sampleRate/2)
        bandWidth = (highCut - lowCut)[:, 0]

        sinHigh = torch.sin(torch.matmul(highCut, self.n_))
        sinLow = torch.sin(torch.matmul(lowCut, self.n_))
        # Equivalent of Eq.4 of the reference paper
        leftHalf = ((sinHigh - sinLow) / (self.n_/2)) * self.window_
        center = 2 * bandWidth.view(-1, 1)
        rightHalf = torch.flip(leftHalf, dims=[1])

        kernels = torch.cat([leftHalf, center, rightHalf], dim=1)
        kernels = kernels / (2 * bandWidth[:, None])
        self.filters = (kernels).view(self.outChannels, 1, self.kernelSize)
        return torch.conv1d(waveforms, self.filters, stride=self.stride, padding=self.padding,
                            dilation=self.dilation, bias=None, groups=1)


class SincNetEncoder(nn.Module):
    """Holds a SincConv1D layer followed by a ChannelNorm layer.

    NOTE(review): ``normMode`` is accepted but not used, and no forward
    method is defined in this class.
    """

    def __init__(self,
                 sizeHidden=512,
                 normMode="layerNorm"):
        super().__init__()
        self.dimEncoded = sizeHidden
        self.conv0 = SincConv1D(1, sizeHidden, 10, stride=5, padding=3)
        self.batchNorm0 = ChannelNorm(sizeHidden)



class CPCEncoder(nn.Module):
    """Convolutional encoder: five strided 1D convolutions, each followed by a
    normalization layer and a ReLU. The product of the strides is 160
    (``DOWNSAMPLING``).
    """

    def __init__(self,
                 sizeHidden=512,
                 normMode="layerNorm", sincNet=False):
        """
        Args:
            - sizeHidden (int): number of channels of every convolution
            - normMode (string): one of "batchNorm", "instanceNorm", "ID"
                                 (no normalization) or "layerNorm"
            - sincNet (bool): if True the first layer is a SincConv1D

        Raises:
            ValueError: if normMode is not a valid mode
        """
        super(CPCEncoder, self).__init__()

        validModes = ["batchNorm", "instanceNorm", "ID", "layerNorm"]
        if normMode not in validModes:
            raise ValueError(f"Norm mode must be in {validModes}")

        if normMode == "instanceNorm":
            def normLayer(x):
                return nn.InstanceNorm1d(x, affine=True)
        elif normMode == "layerNorm":
            normLayer = ChannelNorm
        elif normMode == "ID":
            # "ID" used to fall through to BatchNorm1d; nn.Identity accepts
            # (and ignores) the constructor argument and returns its input.
            # NOTE(review): checkpoints saved with normMode="ID" before this
            # fix hold BatchNorm1d parameters.
            normLayer = nn.Identity
        else:
            normLayer = nn.BatchNorm1d

        self.dimEncoded = sizeHidden
        if sincNet:
            self.conv0 = SincConv1D(1, sizeHidden, 10, stride=5, padding=3)
        else:
            self.conv0 = nn.Conv1d(1, sizeHidden, 10, stride=5, padding=3)
        self.batchNorm0 = normLayer(sizeHidden)
        self.conv1 = nn.Conv1d(sizeHidden, sizeHidden, 8, stride=4, padding=2)
        self.batchNorm1 = normLayer(sizeHidden)
        self.conv2 = nn.Conv1d(sizeHidden, sizeHidden, 4,
                               stride=2, padding=1)
        self.batchNorm2 = normLayer(sizeHidden)
        self.conv3 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1)
        self.batchNorm3 = normLayer(sizeHidden)
        self.conv4 = nn.Conv1d(sizeHidden, sizeHidden, 4, stride=2, padding=1)
        self.batchNorm4 = normLayer(sizeHidden)
        self.DOWNSAMPLING = 160

    def getDimOutput(self):
        """Number of channels of the encoded features."""
        return self.conv4.out_channels

    def forward(self, x):
        x = torch.relu(self.batchNorm0(self.conv0(x)))
        x = torch.relu(self.batchNorm1(self.conv1(x)))
        x = torch.relu(self.batchNorm2(self.conv2(x)))
        x = torch.relu(self.batchNorm3(self.conv3(x)))
        x = torch.relu(self.batchNorm4(self.conv4(x)))
        return x


class CPCAR(nn.Module):
    """Recurrent network (GRU by default, LSTM or RNN on request) run on
    batch-first sequences, optionally in reverse time order and optionally
    carrying its hidden state from one call to the next.
    """

    def __init__(self,
                 dimEncoded,
                 dimOutput,
                 keepHidden,
                 nLevelsGRU,
                 mode="GRU",
                 reverse=False):

        super().__init__()
        self.RESIDUAL_STD = 0.1

        # Any mode other than "LSTM" and "RNN" gives a GRU
        if mode == "LSTM":
            recurrentType = nn.LSTM
        elif mode == "RNN":
            recurrentType = nn.RNN
        else:
            recurrentType = nn.GRU
        self.baseNet = recurrentType(dimEncoded, dimOutput,
                                     num_layers=nLevelsGRU, batch_first=True)

        self.hidden = None
        self.keepHidden = keepHidden
        self.reverse = reverse

    def getDimOutput(self):
        return self.baseNet.hidden_size

    def forward(self, x):
        sequence = torch.flip(x, [1]) if self.reverse else x

        try:
            self.baseNet.flatten_parameters()
        except RuntimeError:
            pass

        output, state = self.baseNet(sequence, self.hidden)

        if self.keepHidden:
            # The stored state is detached from the graph (LSTM states are tuples)
            if isinstance(state, tuple):
                self.hidden = tuple(part.detach() for part in state)
            else:
                self.hidden = state.detach()

        # The output is flipped back so that the time order of the input is preserved
        if not self.reverse:
            return output
        return torch.flip(output, [1])


class CPCModel(nn.Module):
    """Chains an encoder and an auto-regressive network.

    The output of the encoder has its last two dimensions swapped before
    being given to the auto-regressive network.
    """

    def __init__(self,
                 encoder,
                 AR):
        super().__init__()
        self.gEncoder = encoder
        self.gAR = AR

    def forward(self, batchData, label):
        """Return (context features, encoded features, label); the label is passed through unchanged."""
        encodedData = self.gEncoder(batchData)
        encodedData = encodedData.permute(0, 2, 1)
        return self.gAR(encodedData), encodedData, label


class PredictionNetwork(nn.Module):
    """One bias-free linear predictor per prediction step.

    Every predictor maps the context features to the dimension of the encoder
    and is scored against candidate encodings by a mean of element-wise
    products over the feature dimension.
    """

    def __init__(self,
                 nPredicts,
                 dimOutputAR,
                 dimOutputEncoder,
                 dropout=False):

        super().__init__()
        self.predictors = nn.ModuleList()
        self.RESIDUAL_STD = 0.01
        self.dimOutputAR = dimOutputAR

        self.dropout = nn.Dropout(p=0.5) if dropout else None

        encoderIsLarger = dimOutputEncoder > dimOutputAR
        for _ in range(nPredicts):
            predictor = nn.Linear(dimOutputAR, dimOutputEncoder, bias=False)
            self.predictors.append(predictor)
            if not encoderIsLarger:
                continue
            # Custom initialization: a standard normal square block followed
            # by the remaining rows drawn with a small standard deviation
            squareBlock = torch.randn(dimOutputAR, dimOutputAR)
            residualBlock = self.RESIDUAL_STD * torch.randn(
                dimOutputEncoder - dimOutputAR, dimOutputAR)
            predictor.weight.data.copy_(
                torch.cat([squareBlock, residualBlock], dim=0))

    def forward(self, c, candidates):
        """Return, for every predictor, the scores of its candidates.

        ``candidates[k]`` is multiplied with the k-th prediction after a
        singleton dimension has been inserted at position 1 of the
        prediction; the product is averaged over dimension 3.
        """
        assert (len(candidates) == len(self.predictors))

        scores = []
        for k, predictor in enumerate(self.predictors):
            prediction = predictor(c)
            if isinstance(prediction, tuple):
                prediction = prediction[0]
            if self.dropout is not None:
                prediction = self.dropout(prediction)
            prediction = prediction.view(prediction.size(0), 1,
                                         prediction.size(1), prediction.size(2))
            scores.append((prediction * candidates[k]).mean(dim=3))
        return scores


class BaseCriterion(nn.Module):
    """Base class of the criteria."""

    def update(self):
        """Hook doing nothing by default."""
        return None


class CPCUnsupersivedCriterion(BaseCriterion):
    """Contrastive criterion: for each of the ``nPredicts`` future steps, the
    prediction made from the context features must pick the true encoded
    frame (candidate 0) among ``negativeSamplingExt`` negative frames drawn
    from the batch. Returns one cross-entropy loss and one accuracy per step.
    """

    def __init__(self,
                 nPredicts,  # Number of steps
                 dimOutputAR,  # Dimension of G_ar
                 dimOutputEncoder,  # Dimension of the convolutional net
                 negativeSamplingExt,  # Number of negative samples to draw
                 mode=None,
                 dropout=False):
        """
        Args:
            - nPredicts (int): number of future steps to predict
            - dimOutputAR (int): dimension of the context features
            - dimOutputEncoder (int): dimension of the encoded features
            - negativeSamplingExt (int): number of negative samples to draw
            - mode (None or "reverse"): with "reverse" the sequences are
                                        flipped in time before scoring
            - dropout (bool): forwarded to the PredictionNetwork

        Raises:
            ValueError: if mode is neither None nor "reverse"
        """

        super(CPCUnsupersivedCriterion, self).__init__()

        self.wPrediction = PredictionNetwork(
            nPredicts, dimOutputAR, dimOutputEncoder, dropout=dropout)
        self.nPredicts = nPredicts
        self.negativeSamplingExt = negativeSamplingExt
        self.lossCriterion = nn.CrossEntropyLoss()

        if mode not in [None, "reverse"]:
            raise ValueError("Invalid mode")

        self.mode = mode

    def sampleClean(self, encodedData, windowSize):
        """Build the candidates (positive + negatives) of every prediction step.

        Returns:
            - outputs: list of nPredicts tensors of shape
              (batchSize, 1 + negativeSamplingExt, windowSize, dimEncoded),
              the positive sample being at index 0 of dimension 1
            - labelLoss: zeros of shape (batchSize * windowSize,), i.e. the
              index of the positive sample for every prediction
        """

        # NOTE: nNegativeExt is the size of dimension 1 of encodedData, the
        # dimension that is sliced below to extract the positive frames.
        batchSize, nNegativeExt, dimEncoded = encodedData.size()
        outputs = []

        # All the frames of the batch, flattened to (batchSize * nNegativeExt, dimEncoded)
        negExt = encodedData.contiguous().view(-1, dimEncoded)
        # Draw nNegativeExt * batchSize negative samples anywhere in the batch
        batchIdx = torch.randint(low=0, high=batchSize,
                                 size=(self.negativeSamplingExt
                                       * windowSize * batchSize,),
                                 device=encodedData.device)

        # Random shift in [1, nNegativeExt - 1] for every negative sample
        seqIdx = torch.randint(low=1, high=nNegativeExt,
                               size=(self.negativeSamplingExt
                                     * windowSize * batchSize,),
                               device=encodedData.device)

        # Position in the window of every negative sample, laid out as
        # (batchSize, negativeSamplingExt, windowSize)
        baseIdx = torch.arange(0, windowSize, device=encodedData.device)
        baseIdx = baseIdx.view(1, 1,
                               windowSize).expand(1,
                                                  self.negativeSamplingExt,
                                                  windowSize).expand(batchSize, self.negativeSamplingExt, windowSize)
        # Shift added to the position, wrapped around with the modulo below
        seqIdx += baseIdx.contiguous().view(-1)
        seqIdx = torch.remainder(seqIdx, nNegativeExt)

        # Index of every negative frame in the flattened negExt
        extIdx = seqIdx + batchIdx * nNegativeExt
        negExt = negExt[extIdx].view(batchSize, self.negativeSamplingExt,
                                     windowSize, dimEncoded)

        labelLoss = torch.zeros((batchSize * windowSize),
                                dtype=torch.long,
                                device=encodedData.device)

        for k in range(1, self.nPredicts + 1):

            # Positive samples
            # (frames k steps ahead of each of the windowSize positions)
            if k < self.nPredicts:
                posSeq = encodedData[:, k:-(self.nPredicts - k)]
            else:
                posSeq = encodedData[:, k:]

            posSeq = posSeq.view(batchSize, 1, windowSize, dimEncoded)
            # The same negative samples are shared by all the prediction steps
            fullSeq = torch.cat((posSeq, negExt), dim=1)
            outputs.append(fullSeq)

        return outputs, labelLoss

    def getInnerLoss(self):
        # NOTE(review): neither self.orthoLoss nor
        # self.wPrediction.orthoCriterion is defined in the visible code;
        # calling this method presumably raises AttributeError -- confirm.

        return "orthoLoss", self.orthoLoss * self.wPrediction.orthoCriterion()

    def forward(self, cFeature, encodedData):
        """Compute the loss and the accuracy of every prediction step.

        Returns:
            - losses: tensor of shape (1, nPredicts)
            - accuracies: tensor of shape (1, nPredicts), number of correct
              predictions divided by (windowSize * batchSize)
        """

        if self.mode == "reverse":
            encodedData = torch.flip(encodedData, [1])
            cFeature = torch.flip(cFeature, [1])

        batchSize, seqSize, dimAR = cFeature.size()
        # The last nPredicts positions have no complete set of future frames
        windowSize = seqSize - self.nPredicts

        cFeature = cFeature[:, :windowSize]

        sampledData, labelLoss = self.sampleClean(encodedData, windowSize)

        predictions = self.wPrediction(cFeature, sampledData)

        outLosses = [0 for _ in range(self.nPredicts)]
        outAcc = [0 for _ in range(self.nPredicts)]

        for k, locPreds in enumerate(predictions[:self.nPredicts]):
            locPreds = locPreds.permute(0, 2, 1)  # (batchSize, 1 + negativeSamplingExt, windowSize) to
            #                                       (batchSize, windowSize, 1 + negativeSamplingExt)
            locPreds = locPreds.contiguous().view(
                -1, locPreds.size(2))  # (batchSize, windowSize, 1 + negativeSamplingExt) to
            #                            (batchSize * windowSize, 1 + negativeSamplingExt)
            lossK = self.lossCriterion(locPreds, labelLoss)
            outLosses[k] += lossK.view(1, -1)
            # A prediction is correct when the positive sample (index 0) has the highest score
            _, predsIndex = locPreds.max(1)
            outAcc[k] += torch.sum(predsIndex == labelLoss).float().view(1, -1)

        return torch.cat(outLosses, dim=1), torch.cat(outAcc, dim=1) / (windowSize * batchSize)

## Transformers

In [7]:
class ScaledDotProductAttention(nn.Module):
    """Causal scaled dot-product attention over sequences of fixed size, with
    an optional learnt relative-position term added to the scores.
    """

    def __init__(self,
                 sizeSeq,         # Size of the input sequence
                 dk,              # Dimension of the input sequence
                 dropout,         # Dropout parameter
                 relpos=False):   # Do we retrieve positional information ?
        super().__init__()

        self.drop = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=2)
        self.relpos = relpos
        self.sizeSeq = sizeSeq

        if relpos:
            self.Krelpos = nn.Parameter(torch.Tensor(dk, sizeSeq))
            self.initmat_(self.Krelpos)
            self.register_buffer('z', torch.zeros(1, sizeSeq, 1))

        # Additive mask: -inf strictly above the diagonal, zero elsewhere,
        # so that a node never queries data in the future
        isFuture = torch.triu(torch.ones(sizeSeq, sizeSeq), diagonal=1).bool()
        causalMask = torch.zeros(sizeSeq, sizeSeq).masked_fill(isFuture, -float('inf'))
        self.register_buffer('mask', causalMask.unsqueeze(0))

    def initmat_(self, mat, dim=0):
        """Fill ``mat`` in place, uniformly in [-1/sqrt(size), 1/sqrt(size)] where size = mat.size(dim)."""
        bound = 1. / math.sqrt(mat.size(dim))
        mat.data.uniform_(-bound, bound)

    def forward(self, Q, K, V):
        # Input dim : N x sizeSeq x dk
        scores = torch.bmm(Q, K.transpose(-2, -1))

        if self.relpos:
            nBatch = Q.size(0)
            relScores = Q.matmul(self.Krelpos)
            # A column of zeros is prepended, then the tensor is reinterpreted
            # with one more row and its first row is dropped
            relScores = torch.cat((self.z.expand(nBatch, -1, -1), relScores), 2)
            relScores = relScores.view(nBatch, self.sizeSeq + 1, self.sizeSeq)
            scores += relScores[:, 1:, :]

        scale = math.sqrt(K.size(-1))
        attention = self.softmax(scores / scale + self.mask)
        return torch.bmm(self.drop(attention), V)


class MultiHeadAttention(nn.Module):
    """Multi-head wrapper around ``ScaledDotProductAttention``.

    Queries, keys and values are projected with bias-free linear maps, split
    into ``nheads`` heads of width ``dmodel // nheads``, attended head by head,
    merged back and projected once more by ``Wo``.
    """

    def __init__(self,
                 sizeSeq,   # Size of a sequence
                 dropout,   # Dropout parameter
                 dmodel,    # Model's dimension
                 nheads,    # Number of heads in the model
                 abspos):   # Is positional information encoded in the input ?
        super().__init__()
        self.Wo = nn.Linear(dmodel, dmodel, bias=False)
        self.Wk = nn.Linear(dmodel, dmodel, bias=False)
        self.Wq = nn.Linear(dmodel, dmodel, bias=False)
        self.Wv = nn.Linear(dmodel, dmodel, bias=False)
        self.nheads = nheads
        self.dk = dmodel // nheads
        # Relative positions are learned only when the input carries no
        # absolute positional encoding.
        self.Att = ScaledDotProductAttention(sizeSeq, self.dk,
                                             dropout, not abspos)

    def trans_(self, x):
        """Split heads: (bsz, len, nheads * dk) -> (bsz * nheads, len, dk)."""
        nBatch, seqLen = x.size(0), x.size(1)
        perHead = x.view(nBatch, seqLen, self.nheads, self.dk).transpose(1, 2)
        return perHead.contiguous().view(nBatch * self.nheads, seqLen, self.dk)

    def reverse_trans_(self, x):
        """Merge heads: (bsz * nheads, len, dk) -> (bsz, len, nheads * dk)."""
        nBatch, seqLen = x.size(0) // self.nheads, x.size(1)
        perStep = x.view(nBatch, self.nheads, seqLen, self.dk).transpose(1, 2)
        return perStep.contiguous().view(nBatch, seqLen, self.nheads * self.dk)

    def forward(self, Q, K, V):
        """Project, attend per head, merge the heads and apply ``Wo``."""
        queries = self.trans_(self.Wq(Q))
        keys = self.trans_(self.Wk(K))
        values = self.trans_(self.Wv(V))
        merged = self.reverse_trans_(self.Att(queries, keys, values))
        return self.Wo(merged)


class FFNetwork(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        din: size of the input features
        dout: size of the output features
        dff: size of the hidden layer
        dropout: dropout probability applied to the hidden activations
    """

    def __init__(self, din, dout, dff, dropout):
        super().__init__()
        self.lin1 = nn.Linear(din, dff, bias=True)
        self.lin2 = nn.Linear(dff, dout, bias=True)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.relu(self.lin1(x))
        hidden = self.drop(hidden)
        return self.lin2(hidden)


class TransformerLayer(nn.Module):
    """One transformer block: self-attention then feed-forward.

    Each sub-layer is wrapped in a residual connection followed by a
    ``LayerNorm`` (normalisation is applied after the addition).
    """

    def __init__(self, sizeSeq=32, dmodel=512, dff=2048,
                 dropout=0.1, nheads=8,
                 abspos=False):
        super().__init__()
        self.multihead = MultiHeadAttention(sizeSeq, dropout,
                                            dmodel, nheads, abspos)
        self.ln_multihead = nn.LayerNorm(dmodel)
        self.ffnetwork = FFNetwork(dmodel, dmodel, dff, dropout)
        self.ln_ffnetwork = nn.LayerNorm(dmodel)

    def forward(self, x):
        # Self-attention: the same tensor acts as query, key and value
        attended = self.multihead(Q=x, K=x, V=x)
        y = self.ln_multihead(x + attended)
        transformed = self.ffnetwork(y)
        return self.ln_ffnetwork(y + transformed)


class StaticPositionEmbedding(nn.Module):
    """Adds a fixed sinusoidal position table to its input.

    For position ``p`` and channel ``d`` the angle is
    ``p * 10000 ** (-2 * (d // 2) / dmodel)``; even channels store its sine
    and odd channels its cosine. The table is kept in the ``pe`` buffer with
    shape ``(1, seqlen, dmodel)``.
    """

    def __init__(self, seqlen, dmodel):
        super().__init__()
        positions = torch.arange(0., seqlen).unsqueeze(1)    # (seqlen, 1)
        channels = torch.arange(0., dmodel).unsqueeze(0)     # (1, dmodel)
        invFreq = torch.exp(- math.log(10000) * (2*(channels//2)/dmodel))
        # Broadcasting yields the full (seqlen, dmodel) table of angles
        table = positions * invFreq
        table[:, 0::2] = torch.sin(table[:, 0::2])
        table[:, 1::2] = torch.cos(table[:, 1::2])
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # Only the first x.size(1) positions of the table are used
        nSteps = x.size(1)
        return x + self.pe[:, :nSteps, :]


def buildTransformerAR(dimEncoded,    # Output dimension of the encoder
                       nLayers,       # Number of transformer layers
                       sizeSeq,       # Expected size of the input sequence
                       abspos):
    """Stack ``nLayers`` transformer layers into an ``nn.Sequential``.

    When ``abspos`` is true, a ``StaticPositionEmbedding`` is placed in front
    of the stack; the same flag is forwarded to every ``TransformerLayer``.
    """
    stack = [StaticPositionEmbedding(sizeSeq, dimEncoded)] if abspos else []
    for _ in range(nLayers):
        stack.append(TransformerLayer(sizeSeq=sizeSeq,
                                      dmodel=dimEncoded, abspos=abspos))
    return nn.Sequential(*stack)

# Trainer

In [13]:
import numpy as np
import torch
import time
from copy import deepcopy
# import matplotlib.pyplot as plt
import json


def update_logs(logs, logStep, prevlogs=None):
    """Return a normalised copy of ``logs``.

    Every entry is deep-copied, reduced by the matching entry of ``prevlogs``
    when one is given, and divided by ``logStep``. ``logs`` itself is left
    untouched.
    """
    normalised = {}
    for name in logs:
        entry = deepcopy(logs[name])
        if prevlogs is not None:
            entry -= prevlogs[name]
        entry /= logStep
        normalised[name] = entry
    return normalised


def save_logs(data, pathLogs):
    """Serialise ``data`` as 2-space-indented JSON into the file ``pathLogs``."""
    with open(pathLogs, 'w') as logFile:
        json.dump(data, logFile, indent=2)


def save_checkpoint(model_state, criterion_state, optimizer_state, best_state,
                    path_checkpoint):
    """Bundle the given states into one dictionary and ``torch.save`` it.

    The saved dictionary has the keys ``gEncoder``, ``cpcCriterion``,
    ``optimizer`` and ``best``, in that order.
    """
    checkpoint = dict(gEncoder=model_state,
                      cpcCriterion=criterion_state,
                      optimizer=optimizer_state,
                      best=best_state)
    torch.save(checkpoint, path_checkpoint)


def show_logs(text, logs):
    """Print ``text`` followed by a small table for every entry of ``logs``.

    The ``"iter"`` entry is skipped. Every other entry must expose
    ``.shape[0]`` and be iterable; it is printed as a header row numbering the
    steps from 1 and a value row formatted with six decimals.
    """
    separator = '-' * 50
    print("")
    print(separator)
    print(text)

    for name, values in logs.items():
        if name == "iter":
            continue

        nColumns = values.shape[0] + 1
        rowFormat = ' '.join(['{:>16}'] * nColumns)

        header = ['Step'] + [str(k) for k in range(1, nColumns)]
        print(rowFormat.format(*header))

        row = [name] + ["{:10.6f}".format(v) for v in values]
        print(rowFormat.format(*row))

    print(separator)


def trainStep(dataLoader,
              cpcModel,
              cpcCriterion,
              optimizer,
              loggingStep,
              useGPU,
              log2Board=0,
              totalSteps=0):
    """Run one pass over ``dataLoader`` and update the model after every batch.

    Args:
        dataLoader: iterable yielding ``(batchData, label)`` pairs
        cpcModel: called as ``cpcModel(batchData, label)``; must expose
            ``gEncoder`` and ``gAR`` when ``log2Board > 1``
        cpcCriterion: called as ``cpcCriterion(c_feature, encoded_data)`` and
            returning ``(allLosses, allAcc)``; must expose ``wPrediction``
            when ``log2Board > 1``
        optimizer: optimizer stepped and zeroed after every backward pass
        loggingStep: number of batches between two console reports
        useGPU: if true, batches and labels are moved to CUDA
        log2Board: 0 = no remote logging; truthy = per-batch metrics are sent
            to the module-level ``experiment``; > 1 = gradient and weight
            histograms are sent as well
        totalSteps: number of steps already run, used as an offset for the
            step index of everything that is logged remotely

    Returns:
        dict with ``locLoss_train`` and ``locAcc_train`` averaged over the
        batches of this pass, plus ``iter`` (the number of batches seen).
    """
    cpcModel.train()
    cpcCriterion.train()

    startTime = time.perf_counter()
    n_examples = 0
    logs, lastlogs = {}, None
    iterCtr = 0

    if log2Board > 1:
        # Per-module accumulators of the gradients seen during this pass
        gradmapGEncoder = {}
        gradmapGAR = {}
        gradmapWPrediction = {}
        if totalSteps == 0:
            # Very first step overall: record the initial weights
            logWeights(cpcModel.gEncoder, totalSteps)
            logWeights(cpcModel.gAR, totalSteps)
            logWeights(cpcCriterion.wPrediction, totalSteps)

    for step, fulldata in enumerate(dataLoader):
        batchData, label = fulldata
        n_examples += batchData.size(0)
        if useGPU:
            batchData = batchData.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)
        c_feature, encoded_data, label = cpcModel(batchData, label)
        allLosses, allAcc = cpcCriterion(c_feature, encoded_data)
        # The optimised objective is the sum of all returned losses
        totLoss = allLosses.sum()

        totLoss.backward()

        if log2Board > 1:
            # Gradients must be read before optimizer.zero_grad() below
            gradmapGEncoder = updateGradientMap(cpcModel.gEncoder, gradmapGEncoder)
            gradmapGAR = updateGradientMap(cpcModel.gAR, gradmapGAR)
            gradmapWPrediction = updateGradientMap(cpcCriterion.wPrediction, gradmapWPrediction)

        optimizer.step()
        optimizer.zero_grad()

        if "locLoss_train" not in logs:
            # Lazily sized from the second dimension of the first loss tensor
            logs["locLoss_train"] = np.zeros(allLosses.size(1))
            logs["locAcc_train"] = np.zeros(allLosses.size(1))

        logs["locLoss_train"] += (allLosses.mean(dim=0)).detach().cpu().numpy()
        logs["locAcc_train"] += (allAcc.mean(dim=0)).cpu().numpy()
        iterCtr += 1

        if log2Board:
            # The value sent is the running average since the start of this pass
            for t in range(len(logs["locLoss_train"])):
                experiment.log_metric(f"Losses/batch/locLoss_train_{t}", logs["locLoss_train"][t] / iterCtr, step=totalSteps + iterCtr)
                experiment.log_metric(f"Accuracy/batch/locAcc_train_{t}", logs["locAcc_train"][t] / iterCtr, step=totalSteps + iterCtr)

        if (step + 1) % loggingStep == 0:
            new_time = time.perf_counter()
            elapsed = new_time - startTime
            print(f"Update {step + 1}")
            print(f"elapsed: {elapsed:.1f} s")
            print(
                f"{1000.0 * elapsed / loggingStep:.1f} ms per batch, {1000.0 * elapsed / n_examples:.1f} ms / example")
            # Average over the last loggingStep batches only (difference with
            # the totals stored at the previous report)
            locLogs = update_logs(logs, loggingStep, lastlogs)
            lastlogs = deepcopy(logs)
            show_logs("Training loss", locLogs)
            # Restart the timing window
            startTime, n_examples = new_time, 0
            
            if log2Board > 1:
                # Log gradients
                logGradients(gradmapGEncoder, totalSteps + iterCtr, scaleBy=1.0 / iterCtr)
                logGradients(gradmapGAR, totalSteps + iterCtr, scaleBy=1.0 / iterCtr)
                logGradients(gradmapWPrediction, totalSteps + iterCtr, scaleBy=1.0 / iterCtr)
                # Log weights
                logWeights(cpcModel.gEncoder, totalSteps + iterCtr)
                logWeights(cpcModel.gAR, totalSteps + iterCtr)
                logWeights(cpcCriterion.wPrediction, totalSteps + iterCtr)

    # Turn the accumulated sums into averages over the whole pass
    logs = update_logs(logs, iterCtr)
    logs["iter"] = iterCtr
    show_logs("Average training loss on epoch", logs)
    return logs


def valStep(dataLoader,
            cpcModel,
            cpcCriterion,
            useGPU):
    """Evaluate the model on ``dataLoader`` without computing gradients.

    Args:
        dataLoader: iterable yielding ``(batchData, label)`` pairs
        cpcModel: called as ``cpcModel(batchData, label)``
        cpcCriterion: called as ``cpcCriterion(c_feature, encoded_data)`` and
            returning ``(allLosses, allAcc)``
        useGPU: if true, batches and labels are moved to CUDA

    Returns:
        dict with ``locLoss_val`` and ``locAcc_val`` averaged over the batches,
        plus ``iter`` (the number of batches seen).
    """
    # Both modules are switched to evaluation mode once (the original did it twice)
    cpcCriterion.eval()
    cpcModel.eval()
    logs = {}
    iterCtr = 0

    for batchData, label in dataLoader:

        if useGPU:
            batchData = batchData.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)

        with torch.no_grad():
            c_feature, encoded_data, label = cpcModel(batchData, label)
            allLosses, allAcc = cpcCriterion(c_feature, encoded_data)

        if "locLoss_val" not in logs:
            # Lazily sized from the second dimension of the first loss tensor
            logs["locLoss_val"] = np.zeros(allLosses.size(1))
            logs["locAcc_val"] = np.zeros(allLosses.size(1))

        iterCtr += 1
        logs["locLoss_val"] += allLosses.mean(dim=0).cpu().numpy()
        logs["locAcc_val"] += allAcc.mean(dim=0).cpu().numpy()

    # Turn the accumulated sums into averages over the whole pass
    logs = update_logs(logs, iterCtr)
    logs["iter"] = iterCtr
    show_logs("Validation loss:", logs)
    return logs


def updateGradientMap(model, gradMap):
    """Accumulate the current gradients of ``model`` into ``gradMap``.

    Only parameters whose last name component is one of the tracked labels
    below are considered. Each one is stored under ``"Gradients/<name>"``,
    where ``<name>`` is the name given by ``model.named_parameters()``.

    Args:
        model: module whose ``named_parameters()`` are inspected
        gradMap: dict updated in place; missing keys are created

    Returns:
        The same ``gradMap`` object, for convenience.
    """
    trackedLabels = ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
                     'lowHz_', 'bandHz_', 'weight', 'bias')
    for name, param in model.named_parameters():
        if name.split('.')[-1] not in trackedLabels:
            continue
        # A parameter that has not received a gradient (frozen, unused in the
        # last forward pass, or no backward pass yet) has grad == None; adding
        # None to the accumulator would raise, so it is skipped.
        if param.grad is None:
            continue
        key = "%s/%s" % ("Gradients", name)
        if key in gradMap:
            gradMap[key] += param.grad
        else:
            # Own copy: the accumulator must not alias param.grad
            gradMap[key] = param.grad.detach().clone()
    return gradMap


def logGradients(gradMap, step, scaleBy=1.0):
    """Send every tensor of ``gradMap`` to the module-level ``experiment``.

    Each value is moved to the CPU, converted to a numpy array, multiplied by
    ``scaleBy`` and logged as a 3D histogram under its key at ``step``.
    """
    for histName, accumulated in gradMap.items():
        scaled = accumulated.cpu().detach().numpy() * scaleBy
        experiment.log_histogram_3d(scaled, name=histName, step=step)


def logWeights(model, step):
    """Send the tracked parameters of ``model`` to the module-level ``experiment``.

    Only parameters whose last name component is one of the tracked labels
    below are logged. Each one is moved to the CPU, converted to a numpy array
    and logged as a 3D histogram named ``"Parameters/<name>"`` at ``step``.
    """
    trackedLabels = ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
                     'lowHz_', 'bandHz_', 'weight', 'bias')
    # named_parameters() already yields the parameter object, so it does not
    # have to be looked up again attribute by attribute.
    for name, param in model.named_parameters():
        if name.split('.')[-1] not in trackedLabels:
            continue
        experiment.log_histogram_3d(param.cpu().detach().numpy(),
                                    name="%s/%s" % ("Parameters", name), step=step)


def trainingLoop(trainDataset,
                 valDataset,
                 batchSize,
                 samplingMode,
                 cpcModel,
                 cpcCriterion,
                 nEpoch,
                 optimizer,
                 pathCheckpoint,
                 logs,
                 useGPU,
                 log2Board=0):
    """Alternate training and validation passes for up to ``nEpoch`` epochs.

    Args:
        trainDataset, valDataset: objects exposing
            ``getDataLoader(batchSize, samplingMode, randomOffset, numWorkers=...)``
            (third positional argument name assumed — TODO confirm)
        batchSize: batch size used for both loaders
        samplingMode: sampling mode of the training loader (the validation
            loader always uses ``'sequential'``)
        cpcModel, cpcCriterion, optimizer: objects being trained; all three
            must provide ``state_dict()``
        nEpoch: index of the last epoch + 1; training resumes at
            ``len(logs["epoch"])``
        pathCheckpoint: prefix of the checkpoint / log files, or None to
            disable saving
        logs: dict updated in place; must contain ``"epoch"``,
            ``"logging_step"`` and ``"saveStep"``
        useGPU: forwarded to the train / validation steps
        log2Board: forwarded to ``trainStep``; when truthy, per-epoch metrics
            are also sent to the module-level ``experiment``

    On ``KeyboardInterrupt`` the current state is saved with an
    ``_interrupted`` suffix (when ``pathCheckpoint`` is set) and the function
    returns.
    """
    print(f"Running {nEpoch} epochs")
    startEpoch = len(logs["epoch"])
    bestAcc = 0
    bestStateDict = None
    startTime = time.time()
    epoch = 0
    totalSteps = 0

    def saveState(suffix=""):
        # Writes the checkpoint of the current epoch and refreshes the JSON logs
        save_checkpoint(cpcModel.state_dict(), cpcCriterion.state_dict(),
                        optimizer.state_dict(), bestStateDict,
                        f"{pathCheckpoint}_{epoch}{suffix}.pt")
        save_logs(logs, pathCheckpoint + "_logs.json")

    try:
        for epoch in range(startEpoch, nEpoch):
            print(f"Starting epoch {epoch}")
            trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
                                                     True, numWorkers=0)
            valLoader = valDataset.getDataLoader(batchSize, 'sequential', False,
                                                 numWorkers=0)

            print("Training dataset %d batches, Validation dataset %d batches, batch size %d" %
                  (len(trainLoader), len(valLoader), batchSize))

            locLogsTrain = trainStep(trainLoader, cpcModel, cpcCriterion, optimizer, logs["logging_step"],
                                     useGPU, log2Board, totalSteps)

            totalSteps += locLogsTrain['iter']

            locLogsVal = valStep(valLoader, cpcModel, cpcCriterion, useGPU)

            print(f'Ran {epoch + 1} epochs '
                  f'in {time.time() - startTime:.2f} seconds')

            if useGPU:
                torch.cuda.empty_cache()

            currentAccuracy = float(locLogsVal["locAcc_val"].mean())

            if log2Board:
                # trainStep / valStep already return per-batch averages, so the
                # values are logged as they are (no further normalisation).
                for t in range(len(locLogsVal["locLoss_val"])):
                    experiment.log_metric(f"Losses/epoch/locLoss_train_{t}", locLogsTrain["locLoss_train"][t], step=epoch)
                    experiment.log_metric(f"Accuracy/epoch/locAcc_train_{t}", locLogsTrain["locAcc_train"][t], step=epoch)
                    experiment.log_metric(f"Losses/epoch/locLoss_val_{t}", locLogsVal["locLoss_val"][t], step=epoch)
                    experiment.log_metric(f"Accuracy/epoch/locAcc_val_{t}", locLogsVal["locAcc_val"][t], step=epoch)

            if currentAccuracy > bestAcc:
                # Remember the new best score, and snapshot the weights:
                # state_dict() returns references to the live tensors, which
                # keep changing while training goes on.
                bestAcc = currentAccuracy
                bestStateDict = deepcopy(cpcModel.state_dict())

            for key, value in dict(locLogsTrain, **locLogsVal).items():
                if key not in logs:
                    # Pad the epochs for which this entry was not recorded
                    logs[key] = [None for _ in range(epoch)]
                if isinstance(value, np.ndarray):
                    value = value.tolist()
                logs[key].append(value)

            logs["epoch"].append(epoch)

            if pathCheckpoint is not None and (epoch % logs["saveStep"] == 0 or epoch == nEpoch - 1):
                saveState()
    except KeyboardInterrupt:
        if pathCheckpoint is not None:
            saveState("_interrupted")
        return

def run(trainDataset,
        valDataset,
        batchSize,
        samplingMode,
        cpcModel,
        cpcCriterion,
        nEpoch,
        optimizer,
        pathCheckpoint,
        logs,
        useGPU,
        log2Board=0):
    """Entry point of a training run.

    All arguments are forwarded unchanged to ``trainingLoop``. When
    ``log2Board`` is truthy the loop runs inside ``experiment.train()`` and
    ``experiment.end()`` is called once the loop returns.
    """
    loopArgs = (trainDataset, valDataset, batchSize, samplingMode, cpcModel, cpcCriterion, nEpoch, optimizer,
                pathCheckpoint, logs, useGPU, log2Board)

    if not log2Board:
        trainingLoop(*loopArgs)
        return

    with experiment.train():
        trainingLoop(*loopArgs)
        experiment.end()

# Main

In [9]:
import torch
# from dataloader import AudioBatchData
# from model import CPCEncoder, CPCAR, CPCModel, CPCUnsupersivedCriterion
# from trainer import run
from datetime import datetime
import os

In [10]:
# Use CUDA whenever it is available
useGPU = torch.cuda.is_available()
# Metadata column from which the labels are built
labelsBy = 'ensemble'

# Settings shared by the training and the validation split. Both splits read
# their audio from the same folder; the metadata file selects the recordings.
sharedDatasetArgs = dict(rawAudioPath='data/musicnet_lousy/train_data',
                         sizeWindow=20480,
                         labelsBy=labelsBy,
                         CHUNK_SIZE=1e9)

print("Loading the training dataset")
trainDataset = AudioBatchData(metadataPath='data/musicnet_lousy/metadata_train.csv',
                              outputPath='data/musicnet_lousy/train_data/train',
                              NUM_CHUNKS_INMEM=7,
                              useGPU=useGPU,
                              **sharedDatasetArgs)
print("Training dataset loaded")
print("")

print("Loading the validation dataset")
valDataset = AudioBatchData(metadataPath='data/musicnet_lousy/metadata_val.csv',
                            outputPath='data/musicnet_lousy/train_data/val',
                            NUM_CHUNKS_INMEM=1,
                            useGPU=False,
                            **sharedDatasetArgs)
print("Validation dataset loaded")
print("")

Loading the training dataset
Chunks already exist at data/musicnet_lousy/train_data/train/ensemble
Loading files
Loaded 288 sequences, elapsed=143.175 secs
Training dataset loaded

Loading the validation dataset
Chunks already exist at data/musicnet_lousy/train_data/val/ensemble
Loading files
Loaded 32 sequences, elapsed=20.344 secs
Validation dataset loaded



In [None]:
# Experiment switches
useTransformer = True
samplingType = 'samesequence'

# Encoder network
encoderNet = CPCEncoder(512, 'layerNorm', sincNet=True)

# AR Network: either a GRU or a single transformer layer. hiddenGAr is the
# size of its output, which the criterion needs below.
if not useTransformer:
    arNet = CPCAR(512, 256, samplingType == 'sequential', 1, mode="GRU", reverse=False)
    hiddenGAr = 256
else:
    # 20480 is the window size used for the datasets; 160 is presumably the
    # encoder's downsampling factor - TODO confirm
    arNet = buildTransformerAR(512, 1, 20480 // 160, abspos=False)
    hiddenGAr = 512

cpcModel = CPCModel(encoderNet, arNet)
batchSize = 8
cpcModel.supervised = False

criterionArgs = dict(nPredicts=12,
                     dimOutputAR=hiddenGAr,
                     dimOutputEncoder=512,
                     negativeSamplingExt=128,
                     mode=None,
                     dropout=False)
cpcCriterion = CPCUnsupersivedCriterion(**criterionArgs)

if useGPU:
    cpcCriterion.cuda()
    cpcModel.cuda()

# A single optimizer updates the criterion and the model together
gParams = list(cpcCriterion.parameters()) + list(cpcModel.parameters())
lr = 2e-4
optimizer = torch.optim.Adam(gParams, lr=lr, betas=(0.9, 0.999), eps=1e-8)

# Name of the run: sampling type, plus the label column when sampling by category
expDescription = f'{samplingType}_'
if samplingType == 'samecategory':
    expDescription = expDescription + f'{labelsBy}_'

# One time-stamped folder per run; every checkpoint file starts with "checkpoint"
runDirectory = f'logs/{expDescription}{datetime.now().strftime("%d-%m_%H-%M-%S")}'
os.makedirs(runDirectory, exist_ok=True)
pathCheckpoint = os.path.join(runDirectory, "checkpoint")

logs = {"epoch": [], "iter": [], "saveStep": 1, "logging_step": 1000}

run(trainDataset, valDataset, batchSize, samplingType, cpcModel, cpcCriterion, 30, optimizer, pathCheckpoint, logs,
    useGPU, log2Board=2)

Running 30 epochs
Starting epoch 0
Training dataset 10681 batches, Validation dataset 1172 batches, batch size 8
Update 1000
elapsed: 308.9 s
308.9 ms per batch, 38.6 ms / example

--------------------------------------------------
Training loss
            Step                1                2                3                4                5                6                7                8                9               10               11               12
   locLoss_train         2.309381         3.001758         4.006115         4.108611         4.169734         4.198469         4.168210         4.172213         4.228028         4.306034         4.381202         4.461448
            Step                1                2                3                4                5                6                7                8                9               10               11               12
    locAcc_train         0.631744         0.328276         0.136016         0.124933       

In [12]:
# gru = nn.GRU(512, 256, num_layers=1, batch_first=True)
# sizeHidden = 512
# sinc0 = SincConv1D(1, sizeHidden, 10, stride=5, padding=3)
# conv0 = nn.Conv1d(1, sizeHidden, 10, stride=5, padding=3)
# batchNorm0 = ChannelNorm(sizeHidden)
# reluLayer = nn.LeakyReLU(0.1)
# print("GRU")
# for name, param in gru.named_parameters():
#     print("\t", name)
#     print("\t", name.split('.')[-1])
# print("Sinc")
# for name, param in sinc0.named_parameters():
#     print("\t", name)
#     print("\t", name.split('.')[-1])
# print("Conv")
# for name, param in conv0.named_parameters():
#     print("\t", name)
#     print("\t", name.split('.')[-1])
# print("BatchNorm")
# for name, param in batchNorm0.named_parameters():
#     print("\t", name)
#     print("\t", name.split('.')[-1])
# print("Transformer")
# for name, param in arNet.named_parameters():
#     print("\t", name)
#     print("\t", name.split('.')[-1])
# print("ReLU")
# for name, param in reluLayer.named_parameters():
#     print("\t", name)
#     print("\t", name.split('.')[-1])
# print(len(list(reluLayer.named_parameters())))