# 任务描述
- 对特定特征的说话人进行分类。
- 主要目标：学习如何使用transformer。

In [15]:
import os
import json
import torch
import random
from pathlib import Path
from torch.utils.data import Dataset ,DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

## Dataset DataLorder
- Original dataset is [Voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/).

### Dataset

In [12]:
class myDataset(Dataset):
  def __init__(self, data_dir, segment_len=128):
    self.data_dir  ,self.segment_len=data_dir ,segment_len 
    
    speaker2id = json.load(open(data_dir+"/mapping.json"))["speaker2id"]    
    metadata = json.load(open(data_dir+"/metadata.json"))["speakers"] 
    self.speaker_num = len(metadata.keys())
    self.data = []
    for speaker in metadata.keys():
      for utterances in metadata[speaker]:
        self.data.append([utterances["feature_path"], speaker2id[speaker]])
 
  def __len__(self):
    return len(self.data)
 
  def __getitem__(self, index):
    feat_path, speaker = self.data[index]
    # Load preprocessed mel-spectrogram.
    mel = torch.load(os.path.join(self.data_dir, feat_path))
    #最大长度限制在segment_len
    if len(mel) > self.segment_len:      
      start = random.randint(0, len(mel) - self.segment_len)      
      mel = torch.FloatTensor(mel[start:start+self.segment_len])
    else:
      mel = torch.FloatTensor(mel)    
    speaker = torch.FloatTensor([speaker]).long()
    return mel, speaker
 
  def get_speaker_number(self):
    return self.speaker_num

### Datalorder
- 将数据集分为训练数据集（90%）和验证数据集（10%）

In [48]:
def collate_batch(batch):
  # Process features within a batch.
  """Collate a batch of data."""
  print(len(batch))
  mel, speaker = zip(*batch)
  print(mel.shape)
  # Because we train the model batch by batch, we need to pad the features in the same batch to make their lengths the same.
  mel = pad_sequence(mel, batch_first=True, padding_value=-20)    # pad log 10^(-20) which is very small value.
  # mel: (batch size, length, 40)
  return mel, torch.FloatTensor(speaker).long()

def get_dataloader(data_dir, batch_size, n_workers):
  """生成数据加载器"""
  dataset = myDataset(data_dir)
  speaker_num = dataset.get_speaker_number()
  # 将数据集拆分为训练数据集和验证数据集
  trainlen = int(0.9 * len(dataset))  
  trainset, validset = random_split(dataset, [trainlen, len(dataset) - trainlen])

  train_loader = DataLoader(trainset, batch_size=batch_size,  shuffle=True,  drop_last=True,
    num_workers=n_workers,  pin_memory=True,   collate_fn=collate_batch,  )
  valid_loader = DataLoader( validset,  batch_size=batch_size,  num_workers=n_workers,
    drop_last=True,  pin_memory=True,  collate_fn=collate_batch, )

  return train_loader, valid_loader, speaker_num

In [37]:
aa = myDataset('../../data/Dataset')

In [50]:
bb = get_dataloader('../../data/Dataset',20,4)

In [36]:
for i in bb:
    len(i)
    break

torch.Size([3, 25, 300])