In [1]:
import sys
import torch
import numpy as np  
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

def init(module, weight_init, bias_init, gain=1):
    weight_init(module.weight.data, gain=gain)
    bias_init(module.bias.data)
    return module

class AddBias(nn.Module):
    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
#         if x.dim() == 2:
#             bias = self._bias.t().view(1, -1)
#         else:
#             bias = self._bias.t().view(1, -1, 1, 1)
        bias = self._bias.t().view(1, -1)
        return x + bias

#Categorical
class FixedCategorical(torch.distributions.Categorical):
    def sample(self):
        return super().sample().unsqueeze(-1)

    def log_probs(self, actions):
        return super().log_prob(actions.squeeze(-1)).view(actions.size(0), -1).sum(-1).unsqueeze(-1)

    def entropy(self):
        p = self.probs.masked_fill(self.probs <= 0, 1)
        return -1 * p.mul(p.log()).sum(-1)

    def mode(self):
        return self.probs.argmax(dim=-1, keepdim=True)

class Categorical(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Categorical, self).__init__()

        init_ = lambda m: init(
            m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), gain=0.01
        )

        self.linear = init_(nn.Linear(num_inputs, num_outputs))

    def forward(self, x, mask=None,temperature=1):
        x = F.softmax(self.linear(x)/temperature,dim=-1)
        if mask is not None:
            return FixedCategorical(logits=x + torch.log(mask))
        else:
            return FixedCategorical(logits=x)

#Normal
class FixedNormal(torch.distributions.Normal):
    def log_probs(self, actions):
        return super().log_prob(actions).sum(-1, keepdim=True)

    def entropy(self):
        return super().entropy().sum(-1)

    def mode(self):
        return self.mean

class DiagGaussian(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(DiagGaussian, self).__init__()

        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                               constant_(x, 0))

        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
#         desired_init_log_std = -0.693471 #exp(..) ~= 0.5
#         desired_init_log_std = -1.609437 #exp(..) ~=0.2
        desired_init_log_std = -2.302585 #exp(..) ~=0.1
        
        self.logstd = AddBias(desired_init_log_std * torch.ones(num_outputs)) #so no state-dependent sigma

    def forward(self, x, mask=None):
        action_mean = self.fc_mean(x)
#         print('action_mean',action_mean.shape,x.shape)
        zeros = torch.zeros(action_mean.size())
        if x.is_cuda:
            zeros = zeros.cuda()

        action_logstd = self.logstd(zeros)
        return FixedNormal(action_mean, action_logstd.exp())

class ActionHead(nn.Module):
    def __init__(self, input_dim, output_dim, type="categorical"):
        super(ActionHead, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.type = type
        if type == "categorical":
            self.distribution = Categorical(num_inputs=input_dim, num_outputs=output_dim)
        elif type == "normal":
            self.distribution = DiagGaussian(num_inputs=input_dim, num_outputs=output_dim)
        else:
            raise NotImplementedError

    def forward(self, input, mask):
        if self.type == "normal":
            return self.distribution(input)
        else:
            return self.distribution(input, mask)

class Pi_net(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Pi_net, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, int(hidden_size/2))
        self.action_heads = nn.ModuleList()
        self.action_heads.append(ActionHead(int(hidden_size/2), 2, type='categorical'))
        self.action_heads.append(ActionHead(int(hidden_size/2)+1, 1, type='normal'))
        
    def forward(self, s, deterministic=False):
        x = F.relu(self.linear1(s))
        x = F.relu(self.linear2(x))
        
        action_outputs=[]
        head_outputs=[]
        head_outputs.append(x)
        action_type_dist = self.action_heads[0](x,mask=None)
        if deterministic:
            action_type = action_type_dist.mode()
        else:
            action_type = action_type_dist.sample()
            
        head_outputs.append(action_type)
        action_outputs.append(action_type)
        head_output = torch.cat(head_outputs, dim=-1)

        head_dist = self.action_heads[1](head_output,mask=None)
        
        if deterministic:
            head_action = head_dist.mode()
        else:
            head_action = head_dist.rsample()
        
        action_outputs.append(head_action)

        joint_action_log_prob = action_type_dist.log_probs(action_type)
        entropy = action_type_dist.entropy().mean()
        
        joint_action_log_prob += head_dist.log_probs(head_action)

        entropy += head_dist.entropy().mean()
        action_outputs = torch.cat(action_outputs,dim=-1)
        return action_outputs, joint_action_log_prob, entropy
    
class Binary_deter_net(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Binary_deter_net, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, int(hidden_size/2))
        self.linear3 = nn.Linear(int(hidden_size/2),2)
        
    def forward(self, s):
        x = F.relu(self.linear1(s))
        x = F.relu(self.linear2(x))
        x = F.softmax(self.linear3(x),dim=-1)
        
        return x
    
class Multi_deter_net(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Multi_deter_net, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, int(hidden_size/2))
        self.linear3 = nn.Linear(int(hidden_size/2),2)
        self.linear4 = nn.Linear(int(hidden_size/2)+1,1)
        
    def forward(self, s):
        head_outputs=[]
        action_outputs=[]
        
        x = F.relu(self.linear1(s))
        x = F.relu(self.linear2(x))
        x_deter = self.linear3(x)
#         print(x_deter.argmax(-1).shape,s.shape)
        head_outputs.append(x)
#         head_outputs.append(x_deter.argmax(-1,keepdim=True))
        
        x_deter_soft = F.softmax(x_deter,dim=-1)
        head_outputs.append(x_deter_soft.argmax(-1,keepdim=True))
        
        head_output = torch.cat(head_outputs, dim=-1)
#         print(head_output.shape)
        value = self.linear4(head_output)
#         noise = torch.randn_like(value) * 0.1  # 标准差为0.1的高斯噪声
#         value = value + noise
        
        return x_deter_soft,value
    

## Dataset 

In [2]:
from torch.utils.data import Dataset
from collections import defaultdict
import torch, random, copy
from torch.utils.data import RandomSampler, SequentialSampler
from torch.utils.data import DataLoader
from scipy import spatial

class WavDataset(Dataset):
    def __init__(self, spk_list, spk2utt, embd_dict=None,embd_dim=128):
        self.spk_list = spk_list
        self.spk2utt = spk2utt
        self.embd_dict = embd_dict
        self.embd_dim = embd_dim

    def __len__(self):
        return len(self.spk2utt)
    
    def _extract_negative_sample(self, spk2utt, target_spk):
        # 从字典中剔除target_spk
        nontarget_spk_list = [spk for spk in spk2utt.keys() if spk != target_spk]
        # 随机选择一个nontarget_spk
        nontarget_spk = random.choice(nontarget_spk_list)
        # 从embd_dict[nontarget_spk]中随机选择一个value
        negative_sample = random.choice(spk2utt[nontarget_spk])

        return negative_sample
    
    def _extract_positive_sample(self, spk2utt, target_spk, enrol_utts):
        nonenroll_utt_list = [utt for utt in spk2utt[target_spk] if utt not in enrol_utts]
        positive_sample = random.choice(nonenroll_utt_list)
        return positive_sample

    def __getitem__(self, idx):

        spk=self.spk_list[idx]
        embds=[]
        label = random.sample([0,1],1)[0]
        num_enrol = random.sample([i for i in range(1,200)],1)[0]
        enrol_embd = np.zeros((1,self.embd_dim))
        if num_enrol >=len(self.spk2utt[spk]):
            enrol_utts = self.spk2utt[spk]
            num_enrol=len(self.spk2utt[spk])
            label =0
        else:
            enrol_utts = random.sample(self.spk2utt[spk],num_enrol)
            
        if num_enrol>5:
            for utt in enrol_utts[:5]:
                enrol_embd += self.embd_dict[utt]
            enrol_embd = enrol_embd/5
            for utt in enrol_utts[5:]:
                enrol_embd = (1-0.1)*enrol_embd + 0.1*self.embd_dict[utt]
                if random.random() < 0.5:
                    test_utt = self._extract_negative_sample(self.spk2utt,spk)
                    test_embd = self.embd_dict[test_utt]
                    result = 1 - spatial.distance.cosine(enrol_embd, test_embd)
                    if result>0.51:
                        enrol_embd = (1-0.1)*enrol_embd + 0.1*test_embd
        else:
            for utt in enrol_utts:
                enrol_embd += self.embd_dict[utt]
            enrol_embd = enrol_embd/len(enrol_utts)

        if label:
            test_utt = self._extract_positive_sample(self.spk2utt,spk,enrol_utts)
        else:
            test_utt = self._extract_negative_sample(self.spk2utt,spk)

        test_embd = self.embd_dict[test_utt]
        result = 1 - spatial.distance.cosine(enrol_embd, test_embd)

        embds_cat = np.concatenate((enrol_embd,test_embd),axis=1)
        embds_cat = torch.tensor(embds_cat, dtype=torch.float).squeeze(0)
        action_outputs=[]
        action_outputs.append(embds_cat)
        action_outputs = torch.cat(action_outputs,dim=-1)
        if label:
            return action_outputs,label,torch.tensor([0.1])
        else:
            return action_outputs,label,torch.tensor([0.1*result])
    
class WavBatchSampler(object):
    def __init__(self, dataset, shuffle=False, batch_size=1, drop_last=False):
        self.batch_size = batch_size
        self.drop_last = drop_last

        if shuffle:
            self.sampler = RandomSampler(dataset)
        else:
            self.sampler = SequentialSampler(dataset)

    def _renew(self):
        return []

    def __iter__(self):
        batch= self._renew()
        for idx in self.sampler:
            batch.append((idx))
            if len(batch) == self.batch_size:
                yield batch
                batch = self._renew()
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
        
def acc_topk(x,y,topk=1):
    correct = 0
    maxk = max((1,topk))
    y_resize = y.view(-1,1)
    _, pred = x.topk(maxk, 1, True, True)
    correct += torch.eq(pred, y_resize).sum().float().item()
    return correct / len(y)

## Prepare embedding

In [5]:
import glob
embds_dicts={}
for npy_path in glob.glob('./DRL-TU/egs/embed/time_varying_all_T_epoch21_rank*.npy'):
    print(npy_path)
    embds_dict = np.load(npy_path,allow_pickle=True).item()
    embds_dicts ={**embds_dicts,**embds_dict}
    
for npy_path in glob.glob('./DRL-TU/egs/embed/vox2dev_filter_epoch21_rank*.npy'):
    print(npy_path)
    embds_dict = np.load(npy_path,allow_pickle=True).item()
    embds_dicts ={**embds_dicts,**embds_dict}

utt2spk = {i.split()[0]:i.split()[1] for i in open('./data/combine_smiiptv_vox2dev/utt2spk')}
spk2utt = {i.split()[0]:i.split()[1:] for i in open('./data/combine_smiiptv_vox2dev/spk2utt')}
spk_list = [i.split()[0] for i in open('./data/combine_smiiptv_vox2dev/spk2utt')]

utt2dur={i.split()[0]:i.split()[1] for i in open('./data/combine_smiiptv_vox2dev/dur.scp')}

spk2utt_new=defaultdict(list)
for spk in spk2utt:
    if spk[:2]=='id':
        spk2utt_new[spk]=spk2utt[spk]
    else:
        for utt in spk2utt[spk]:
            if float(utt2dur[utt])<1:
                continue
            else:
                spk2utt_new[spk].append(utt)

spk2utt = spk2utt_new

./DRL-TU/egs/embed/time_varying_all_T_epoch21_rank0.npy
./DRL-TU/egs/embed/time_varying_all_T_epoch21_rank1.npy
./DRL-TU/egs/embed/time_varying_all_T_epoch21_rank2.npy
./DRL-TU/egs/embed/vox2dev_filter_epoch21_rank0.npy
./DRL-TU/egs/embed/vox2dev_filter_epoch21_rank1.npy
./DRL-TU/egs/embed/vox2dev_filter_epoch21_rank2.npy
./DRL-TU/egs/embed/vox2dev_filter_epoch21_rank3.npy


# training model

In [6]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import sys 
from utils.utils import AverageMeter,get_lr

batch_size=512
dataset = WavDataset(spk_list, spk2utt=spk2utt, embd_dict=embds_dicts)
batch_sampler = WavBatchSampler(dataset, shuffle=True, batch_size=batch_size, drop_last=True)
train_loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=12, pin_memory=True)

    
criterion_ce = nn.CrossEntropyLoss()
criterion_mse = nn.MSELoss()

# model_deter= Binary_deter_net(256,256).to('cuda:3')
model_deter= Multi_deter_net(256,256).to('cuda:3')

optimizer = torch.optim.SGD(list(model_deter.parameters()),
                            lr=0.1, momentum=0.95, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20,40,60], gamma=0.1, last_epoch=-1)

model_deter.train()


for epoch in range(0,80):
    losses, top1 = AverageMeter(), AverageMeter()
    for i,(data,label,label_v) in enumerate(train_loader):
    
        data,label,label_v = data.to('cuda:3'), label.long().to('cuda:3'), label_v.float().to('cuda:3')

        outputs,values = model_deter(data)
        loss1 = criterion_ce(outputs,label) # cross entropy loss
        loss2 = criterion_mse(values,label_v) # mse loss
        loss= loss1+loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1 = acc_topk(outputs.data, label)*100
        losses.update(loss.data.item(), batch_size)
        top1.update(prec1, batch_size)

        print('Epoch [%d][%d/%d]\t ' % (epoch, i+1, len(train_loader)) + 
              'Loss %.4f %.4f\t' % (losses.val, losses.avg) +
              'Top1 Acc %3.3f %3.3f\t' % (top1.val, top1.avg) +
              'LR %.6f\n' % get_lr(optimizer))

    scheduler.step()

Epoch [0][1/12]	 Loss 0.6975 0.6975	Top1 Acc 64.844 64.844	LR 0.100000

Epoch [0][2/12]	 Loss 0.6915 0.6945	Top1 Acc 65.820 65.332	LR 0.100000

Epoch [0][3/12]	 Loss 0.6842 0.6910	Top1 Acc 66.211 65.625	LR 0.100000

Epoch [0][4/12]	 Loss 0.6838 0.6892	Top1 Acc 64.844 65.430	LR 0.100000

Epoch [0][5/12]	 Loss 0.6834 0.6881	Top1 Acc 63.672 65.078	LR 0.100000

Epoch [0][6/12]	 Loss 0.6633 0.6839	Top1 Acc 69.141 65.755	LR 0.100000

Epoch [0][7/12]	 Loss 0.6547 0.6797	Top1 Acc 68.359 66.127	LR 0.100000

Epoch [0][8/12]	 Loss 0.6481 0.6758	Top1 Acc 67.969 66.357	LR 0.100000

Epoch [0][9/12]	 Loss 0.6594 0.6740	Top1 Acc 64.648 66.168	LR 0.100000

Epoch [0][10/12]	 Loss 0.6505 0.6716	Top1 Acc 66.602 66.211	LR 0.100000

Epoch [0][11/12]	 Loss 0.6463 0.6693	Top1 Acc 66.602 66.246	LR 0.100000

Epoch [0][12/12]	 Loss 0.6301 0.6661	Top1 Acc 68.359 66.423	LR 0.100000

Epoch [1][1/12]	 Loss 0.6313 0.6313	Top1 Acc 68.164 68.164	LR 0.100000

Epoch [1][2/12]	 Loss 0.6348 0.6330	Top1 Acc 67.969 68.066	LR

Epoch [9][9/12]	 Loss 0.5637 0.5889	Top1 Acc 69.922 67.969	LR 0.100000

Epoch [9][10/12]	 Loss 0.5911 0.5891	Top1 Acc 64.648 67.637	LR 0.100000

Epoch [9][11/12]	 Loss 0.5534 0.5859	Top1 Acc 71.484 67.987	LR 0.100000

Epoch [9][12/12]	 Loss 0.5852 0.5858	Top1 Acc 66.016 67.822	LR 0.100000

Epoch [10][1/12]	 Loss 0.5457 0.5457	Top1 Acc 69.922 69.922	LR 0.100000

Epoch [10][2/12]	 Loss 0.5492 0.5474	Top1 Acc 69.336 69.629	LR 0.100000

Epoch [10][3/12]	 Loss 0.5484 0.5478	Top1 Acc 69.922 69.727	LR 0.100000

Epoch [10][4/12]	 Loss 0.5526 0.5490	Top1 Acc 71.094 70.068	LR 0.100000

Epoch [10][5/12]	 Loss 0.5443 0.5480	Top1 Acc 71.289 70.312	LR 0.100000

Epoch [10][6/12]	 Loss 0.5333 0.5456	Top1 Acc 74.414 70.996	LR 0.100000

Epoch [10][7/12]	 Loss 0.5179 0.5416	Top1 Acc 78.906 72.126	LR 0.100000

Epoch [10][8/12]	 Loss 0.5313 0.5403	Top1 Acc 78.711 72.949	LR 0.100000

Epoch [10][9/12]	 Loss 0.5062 0.5365	Top1 Acc 83.594 74.132	LR 0.100000

Epoch [10][10/12]	 Loss 0.5126 0.5341	Top1 Acc 83.00

KeyboardInterrupt: 

In [7]:
torch.save({'model': model_deter.state_dict()},
           './DRL-TU/egs/exp/DRL-TU-MH_pretrained/pretrain_v1.pkl')