In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import torch as T
from tqdm import tqdm
from tqdm import trange
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
import torch.nn as nn
import re
import numpy as np
import os
import json
import random
import gc
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
'''for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))'''

filepath = '/kaggle/input/ml2022spring-hw4/Dataset'
segment_len = 512
train_ratio = 0.8
batch_size = 512
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
class voxData(Dataset):
    def __init__(self, path, train=True, seg_len=128):
        with open(os.path.join(path, 'mapping.json'), 'r') as a:
            self.mapping = json.load(a)["speaker2id"]
        self.data = []
        self.seg_len = seg_len
        self.path = path
        if train == True:
            self.train = True
            with open(os.path.join(path, 'metadata.json'), 'r') as a:
                self.meta = json.load(a)['speakers']
            for i in self.meta.keys():
                ids = self.mapping[i]
                self.data.extend([[j['feature_path'], ids] for j in self.meta[i]])

        else:
            self.train = False
            with open(os.path.join(path, 'testdata.json'), 'r') as a:
                self.meta = json.load(a)['utterances']
            for i in self.meta['utterances']:
                self.data.append([i['feature_path']])
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        if self.train:
            file, spk = self.data[idx]
            spk = T.FloatTensor([spk]).long()
        else:
            file = self.data[idx]
        pt = T.load(os.path.join(self.path, file))
        if pt.shape[0] >= self.seg_len:
            start = random.randint(0, pt.shape[0]-self.seg_len)
            if self.train:
                return pt[start:start+self.seg_len], spk
            else:
                return pt[start:start+self.seg_len]
        else:
            if self.train:
                return pt, spk
            else:
                return pt
        return pt
    def num_spks(self):
        return len(self.meta.keys())

In [3]:
class Vxclf(nn.Module):
    def __init__(self, firstlayer = 40, numspks = 600):
        super().__init__()
        self.prenet = nn.Linear(40, firstlayer)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=firstlayer, dim_feedforward=256, nhead=1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
        self.outnet = nn.Sequential(
            nn.Linear(firstlayer, firstlayer),
            nn.ReLU(),
            nn.Linear(firstlayer, numspks)
        )
    def forward(self, x):
        out = self.prenet(x)
        #print('prenet output', out.shape)
        out = self.encoder_layer(out)
        #print('encoder output', out.shape)
        #out = out.transpose(0, 1)
        #print('transpose output', out.shape)
        # mean pooling
        stats = out.mean(dim=1)
        #print('mean output', stats.shape)
        # out: (batch, n_spks)
        out = self.outnet(stats)
        #print('outnet output', out.shape)
        return out

In [4]:
vxd = voxData(filepath, seg_len=segment_len)
#print(a.meta)
nspks = vxd.num_spks()
train_len = int(train_ratio*len(vxd))

traindata, validdata = random_split(vxd, [train_len, len(vxd)-train_len])
chklabel = set()
for _, j in traindata:
    chklabel.add(j.item())
    if len(chklabel) == nspks:
        break
assert(len(chklabel)==nspks),"resample training data"




In [5]:
def batch_fn(batch):
    data, label = zip(*batch)
    data = pad_sequence(data, batch_first=True, padding_value=1e-20)
    return data, T.FloatTensor(label).long()
trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=True, 
                pin_memory=True, collate_fn=batch_fn)
valdloader = DataLoader(validdata, batch_size=batch_size,
                pin_memory=True, collate_fn=batch_fn)

In [6]:
device = T.device("cuda" if T.cuda.is_available() else "cpu")
clf = Vxclf(firstlayer=100).to(device)
optimizer = T.optim.AdamW(clf.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for num in trange(15):
    clf.train()
    for x, y in trainloader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        temp = clf(x)
        
        loss = criterion(temp, y.long().squeeze())
        out = temp.argmax(1)
        accuracy = T.mean((out == y).float())
        loss.backward()
        optimizer.step()
        #print(accuracy)
    vaccuracy = []
    clf.eval()
    for m, n in valdloader:
        m = m.to(device)
        n = n.to(device)
        with T.no_grad():
            temp = clf(m)
            out = temp.argmax(1)
            vaccuracy.append(T.mean((out == n).float()).to('cpu').item())
    print('Average validate accuracy', np.mean(vaccuracy))

  7%|▋         | 1/15 [06:42<1:33:58, 402.72s/it]

Average validate accuracy 0.04613742236371921


 13%|█▎        | 2/15 [07:34<42:29, 196.14s/it]  

Average validate accuracy 0.15501504294250323


 20%|██        | 3/15 [08:25<25:58, 129.88s/it]

Average validate accuracy 0.23585258165131445


 27%|██▋       | 4/15 [09:16<18:08, 98.92s/it] 

Average validate accuracy 0.28976611080376996


 33%|███▎      | 5/15 [10:08<13:39, 81.90s/it]

Average validate accuracy 0.34043090110239776


 40%|████      | 6/15 [11:00<10:46, 71.87s/it]

Average validate accuracy 0.3882302991721941


 47%|████▋     | 7/15 [11:53<08:45, 65.70s/it]

Average validate accuracy 0.4183229816996533


 53%|█████▎    | 8/15 [12:46<07:10, 61.55s/it]

Average validate accuracy 0.43372476748798205


 60%|██████    | 9/15 [13:39<05:53, 58.93s/it]

Average validate accuracy 0.4628421003403871


 67%|██████▋   | 10/15 [14:32<04:45, 57.12s/it]

Average validate accuracy 0.4801654699056045


 73%|███████▎  | 11/15 [15:24<03:42, 55.58s/it]

Average validate accuracy 0.5037994954897009


 80%|████████  | 12/15 [16:17<02:44, 54.83s/it]

Average validate accuracy 0.5245511454084645


 87%|████████▋ | 13/15 [17:10<01:48, 54.27s/it]

Average validate accuracy 0.5352144798506862


 93%|█████████▎| 14/15 [18:03<00:53, 53.78s/it]

Average validate accuracy 0.5486315994159036


100%|██████████| 15/15 [18:56<00:00, 75.76s/it]

Average validate accuracy 0.5615440607070923





In [7]:
vaccuracy 

[0.578125,
 0.58984375,
 0.564453125,
 0.544921875,
 0.556640625,
 0.564453125,
 0.59375,
 0.529296875,
 0.580078125,
 0.5625,
 0.587890625,
 0.56640625,
 0.5546875,
 0.509765625,
 0.60546875,
 0.5859375,
 0.513671875,
 0.591796875,
 0.580078125,
 0.5859375,
 0.560546875,
 0.56640625,
 0.44285714626312256]

In [8]:
pwd

'/kaggle/working'