In [1]:
import kaldiio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from sklearn.metrics import roc_auc_score
import random

In [2]:
class FHSData(torch.utils.data.Dataset):
    """Map-style dataset of per-utterance feature matrices and integer labels.

    Utterance ids and feature locations are read from a two-column
    ``feats.scp`` file; labels come from a two-column ``utt2label`` file.
    When ``cmvn_file`` is given, every loaded matrix is mean/variance
    normalised with statistics derived from that file.
    """

    def __init__(self, feats_scp, utt2label, cmvn_file=None):
        super().__init__()
        self.uttids = []      # utterance ids, in feats.scp order
        self.utt2feats = {}   # utterance id -> feature location string
        self.utt2labs = {}    # utterance id -> integer label

        with open(feats_scp) as scp_fh:
            for row in scp_fh:
                fields = row.rstrip().split()
                self.uttids.append(fields[0])
                self.utt2feats[fields[0]] = fields[1]

        with open(utt2label) as lab_fh:
            for row in lab_fh:
                fields = row.rstrip().split()
                self.utt2labs[fields[0]] = int(fields[1])

        # (means, stds) tuple, or None when no normalisation is requested.
        self.cmvn = self._load_cmvn(cmvn_file) if cmvn_file is not None else None

    def __len__(self):
        """Number of utterances listed in ``feats.scp``."""
        return len(self.uttids)

    def __getitem__(self, idx):
        """Load, optionally normalise, and package the ``idx``-th utterance.

        Returns a dict with the utterance id (``'uttid'``), a float tensor of
        features (``'feats'``) and a long tensor label (``'target'``).
        """
        utt = self.uttids[idx]
        # Copy so the in-place CMVN below never writes into the loader's buffer.
        mat = kaldiio.load_mat(self.utt2feats[utt]).copy()
        if self.cmvn is not None:
            mat = self._apply_cmvn(mat, self.cmvn)
        lab = self.utt2labs[utt]
        return {
            'uttid': utt,
            'feats': torch.tensor(mat).float(),
            'target': torch.tensor(lab).long(),
        }

    def _load_cmvn(self, cmvn_file):
        """Turn a 2-row accumulator matrix into per-dimension (means, stds).

        Row 0 holds the feature sums with the frame count in its last column;
        row 1 holds the sums of squares. Variance is floored at 1e-10 before
        taking the square root.
        """
        stats = kaldiio.load_mat(cmvn_file)
        assert stats.shape[0] == 2
        frame_count = stats[0, -1]
        means = stats[0, :-1] / frame_count
        second_moment = stats[1, :-1] / frame_count
        stds = np.sqrt(np.maximum(1e-10, second_moment - means ** 2))
        return means, stds

    def _apply_cmvn(self, features, cmvn):
        """Normalise ``features`` in place with ``cmvn = (means, stds)`` and return it."""
        # https://github.com/kaldi-asr/kaldi/blob/master/src/transform/cmvn.cc
        means, stds = cmvn
        np.subtract(features, means, out=features)
        np.divide(features, stds, out=features)
        return features

def _collate_fn(batch):
    # max_len = max(len(ex['feats']) for ex in batch)
    max_len = 1024
    batch_feats = torch.zeros(len(batch), max_len, batch[0]['feats'].shape[-1])
    batch_targets = []
    batch_lens = []
    for i, ex in enumerate(batch):
        feats = ex['feats']
        tgt = ex['target']
        batch_feats[i, :len(feats)] = feats
        batch_lens.append(len(feats))
        batch_targets.append(tgt)
    
    batch_targets = torch.stack(batch_targets, dim=0).long()
    batch_lens = torch.tensor(batch_lens).long()
    return {'feats': batch_feats,'feats_len': batch_lens, 'targets': batch_targets}

In [3]:
# Training split: FHSData reads the feature list and labels, then the name
# `tr_ds` is rebound from the Dataset to a shuffling DataLoader over it.
# NOTE(review): paths are absolute and machine-specific.
tr_wav_scp = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/train_fra/feats.scp'
tr_utt2label = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/train_fra/utt2label'
cmvn_file = '/data/sls/scratch/sameerk/fhs_prepared/kaldi_data/fbanks/cmvn.ark'
tr_ds = FHSData(tr_wav_scp, tr_utt2label, cmvn_file)
tr_ds = torch.utils.data.DataLoader(tr_ds, batch_size=32, collate_fn=_collate_fn, drop_last=False, 
                                    num_workers=2, shuffle=True)

# Dev split: same CMVN file (re-assigned to the identical path), no shuffling.
dev_wav_scp = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/dev_fra/feats.scp'
dev_utt2label = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/dev_fra/utt2label'
cmvn_file = '/data/sls/scratch/sameerk/fhs_prepared/kaldi_data/fbanks/cmvn.ark'
dev_ds = FHSData(dev_wav_scp, dev_utt2label, cmvn_file)
dev_ds = torch.utils.data.DataLoader(dev_ds, batch_size=32, collate_fn=_collate_fn, drop_last=False, 
                                    num_workers=2, shuffle=False)

In [None]:
# Sanity check: pull a single batch from each loader and show the padded
# feature shape, the labels and the per-example frame counts.
for ex in tr_ds:
    print(ex['feats'].shape)
    print(ex['targets'])
    print(ex['feats_len'])
    break

for ex in dev_ds:
    print(ex['feats'].shape)
    print(ex['targets'])
    print(ex['feats_len'])
    break

In [8]:
class CNN(nn.Module):
    """Four-convolution network with a linear + sigmoid head.

    ``forward`` expects ``x`` laid out as (batch, frames, 40): it swaps the
    last two axes and adds a channel axis so that ``conv1``'s (40, 1) kernel
    collapses the 40-row feature axis. ``conv2``-``conv4`` then convolve along
    the frame axis with width-18 kernels, with a stride-2 max-pool after
    ``conv2`` and ``conv3``.

    The final max-pool width (255) is hard-coded for 1024-frame inputs
    (1024 -> 1023 -> 512 -> 511 -> 256 -> 255 through the conv/pool stack).
    Other input lengths leave more or fewer than one pooled column, which
    then does not match ``fc1``'s input size.

    Args:
        num_classes: number of output units of ``fc1`` (default 1).
    """

    def __init__(self, num_classes=1):
        super().__init__()

        self.num_classes = num_classes

        # Channel widths of the four convolutions.
        conv1_C = 256
        conv2_C = 512
        conv3_C = 1024
        conv4_C = 2048
        conv_W = 18
        pad = int(np.floor(conv_W / 2) - 1)
        stride_W = 2
        pool_W = 3
        self.conv1 = nn.Conv2d(1, conv1_C, kernel_size=(40,1), stride=(1,1), padding=(0,0))
        self.bn1 = nn.BatchNorm2d(conv1_C)
        self.conv2 = nn.Conv2d(conv1_C, conv2_C, kernel_size=(1,conv_W), stride=(1,1), padding=(0,pad))
        self.bn2 = nn.BatchNorm2d(conv2_C)
        self.conv3 = nn.Conv2d(conv2_C, conv3_C, kernel_size=(1,conv_W), stride=(1,1), padding=(0,pad))
        self.bn3 = nn.BatchNorm2d(conv3_C)
        self.conv4 = nn.Conv2d(conv3_C, conv4_C, kernel_size=(1,conv_W), stride=(1,1), padding=(0,pad))
        self.bn4 = nn.BatchNorm2d(conv4_C)

        # Down-sampling max-pool shared by the intermediate layers.
        pad = int(np.floor(pool_W / 2))
        self.pool = nn.MaxPool2d(kernel_size=(1,pool_W), stride=(1,stride_W), padding=(0,pad))

        # Output head: one linear layer on the globally pooled conv4 channels.
        # map_W is hard coded for 1024-frame inputs (see class docstring).
        map_W = 255
        embedding_dim = conv4_C
        self.fc1 = nn.Linear(embedding_dim, num_classes)
        self.fc2 = nn.Sigmoid()

        # Global max pooling over the remaining frame axis.
        self.poolMax = nn.MaxPool2d(kernel_size=(1,map_W), stride=(1,map_W), padding=(0,0))

    def forward(self, x):
        """Return sigmoid scores of shape (batch, num_classes) for ``x``.

        The method depends only on its argument; it no longer reads the
        notebook-global ``batch``, which made it fail with NameError whenever
        that global was not defined.
        """
        x = x.transpose(1, 2)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.bn2(F.relu(self.conv2(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.pool(x)
        x = self.bn4(F.relu(self.conv4(x)))
        x = self.poolMax(x)
        x = x.squeeze(2)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        
    
#### Model without FC layer #####
class CNNNoFC(nn.Module):
    """Variant of the four-convolution CNN without the linear output layer.

    The convolution / pooling stack matches ``CNN``; instead of a learned
    linear head, the globally max-pooled channel activations are averaged into
    a single value per example and passed through a sigmoid. ``num_classes``
    is stored but the output always has exactly one column.

    The final max-pool width (255) is hard-coded for 1024-frame inputs.
    """

    def __init__(self, num_classes=1):
        # Was `super(CNN, self).__init__()`, which raises TypeError because
        # CNNNoFC does not inherit from CNN.
        super().__init__()

        self.num_classes = num_classes

        # Channel widths of the four convolutions.
        conv1_C = 256
        conv2_C = 512
        conv3_C = 1024
        conv4_C = 2048
        conv_W = 18
        pad = int(np.floor(conv_W / 2) - 1)
        stride_W = 2
        pool_W = 3
        self.conv1 = nn.Conv2d(1, conv1_C, kernel_size=(40,1), stride=(1,1), padding=(0,0))
        self.bn1 = nn.BatchNorm2d(conv1_C)
        self.conv2 = nn.Conv2d(conv1_C, conv2_C, kernel_size=(1, conv_W), stride=(1,1), padding=(0,pad))
        self.bn2 = nn.BatchNorm2d(conv2_C)
        self.conv3 = nn.Conv2d(conv2_C, conv3_C, kernel_size=(1, conv_W), stride=(1,1), padding=(0,pad))
        self.bn3 = nn.BatchNorm2d(conv3_C)
        self.conv4 = nn.Conv2d(conv3_C, conv4_C, kernel_size=(1, conv_W), stride=(1,1), padding=(0,pad))
        self.bn4 = nn.BatchNorm2d(conv4_C)

        # Down-sampling max-pool shared by the intermediate layers.
        pad = int(np.floor(pool_W / 2))
        self.pool = nn.MaxPool2d(kernel_size=(1,pool_W), stride=(1,stride_W), padding=(0,pad))

        # map_W is hard coded for 1024-frame inputs.
        map_W = 255
        self.fc2 = nn.Sigmoid()

        # Global max pooling over the remaining frame axis.
        self.poolMax = nn.MaxPool2d(kernel_size=(1, map_W), stride=(1,map_W), padding=(0,0))

    def forward(self, x):
        """Return sigmoid scores of shape (batch, 1) for ``x``.

        The method depends only on its argument; the stray read of the
        notebook-global ``batch`` has been removed.
        """
        x = x.transpose(1, 2)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.bn2(F.relu(self.conv2(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.pool(x)
        x = self.bn4(F.relu(self.conv4(x)))
        x = self.poolMax(x)
        x = x.squeeze(2)
        x = x.reshape(x.size(0), -1)
        # No linear layer: average the pooled channel activations instead.
        x = torch.mean(x, dim=1)[:, None]
        x = self.fc2(x)
        return x

In [None]:
# Instantiate the full CNN with default arguments for inspection.
cnn_org = CNN()

In [None]:
# Bare expression: the notebook displays the module's repr.
cnn_org

In [None]:
# Fetch a single collated training batch for smoke tests.
batch = next(iter(tr_ds))

In [None]:
# Smoke test: forward one batch through the untrained CNN and check the output shape.
x = cnn_org(batch['feats'])
print(x.shape)

In [None]:
device = 'cuda'  # NOTE(review): not referenced below; the loop calls .cuda() directly

# Hyper-parameter sweep for the full CNN: every (weight decay, learning rate,
# seed) combination trains a fresh model for 5 epochs with SGD + momentum.
# The loop variables (`model`, `auc`, `y_pred`, `y_true`, ...) stay in the
# notebook namespace afterwards and are read by later cells.
for wd in [1e-7, 1e-5, 1e-3]:
    for lr in [0.01, 0.005, 0.001, 0.0005]:
        for seed in [1111, 2222, 1234, 3333]:
            print("++++++ Exp seed %d wd %s lr %f ++++++++" % (seed, str(wd), lr))
            # Seed torch, random and numpy before the model is built.
            torch.manual_seed(seed)
            random.seed(seed)
            np.random.seed(seed)
            model = CNN()
            model.cuda()
            optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)

            # Best dev AUC seen so far for this configuration.
            auc_best = 0.
            for epoch in range(5):
                # ---- training pass ----
                model.train()
                epoch_loss = 0.
                y_pred = []
                y_true = []
                for batch in tqdm.tqdm(tr_ds, total=len(tr_ds)):
                    x = batch['feats']
                    # use lens to do average pooling properly
                    x_lens = batch['feats_len']
                    tgts = batch['targets']
                    x = x.cuda()
                    tgts = tgts.cuda()
                    # Model output is already a sigmoid probability, hence plain BCE.
                    x = model(x).squeeze(1)
                    loss = F.binary_cross_entropy(x, tgts.float(), reduction='mean')
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    epoch_loss += float(loss)
                    y_pred.append(x.detach().cpu().numpy())
                    y_true.append(tgts.cpu().numpy())

                # Training AUC over the predictions collected during the epoch.
                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                print('epoch loss: %f' % (epoch_loss/len(tr_ds)))
                print('auc: %f' % auc)

                # ---- dev pass (no gradient, eval mode) ----
                epoch_loss = 0.
                y_pred = []
                y_true = []
                model.eval()
                with torch.no_grad():
                    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
                        x = batch['feats']
                        # use lens to do average pooling properly
                        x_lens = batch['feats_len']
                        tgts = batch['targets']
                        x = x.cuda()
                        tgts = tgts.cuda()
                        x = model(x).squeeze(1)
                        loss = F.binary_cross_entropy(x, tgts.float(), reduction='mean')
                        epoch_loss += float(loss)
                        y_pred.append(x.detach().cpu().numpy())
                        y_true.append(tgts.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                # Checkpoint whenever dev AUC improves; the file name encodes
                # the AUC and the hyper-parameters and is written to the CWD.
                if auc > auc_best:
                    print("saving ckpt at auc_%f_seed_%d_wd_%s_lr_%f.pt" % (auc, seed, str(wd), lr))
                    torch.save(model.state_dict(), "auc_%f_seed_%d_wd_%s_lr_%f.pt" % (auc, seed, str(wd), lr))
                    auc_best = auc

                print('epoch dev loss: %f' % (epoch_loss/len(dev_ds)))
                print('auc dev: %f' % auc)
    

In [None]:
# Shows whatever `auc` currently holds in the notebook namespace
# (presumably the dev AUC of the most recently finished epoch; depends on cell execution order).
print(auc)

In [None]:
from sklearn.metrics import classification_report
# Threshold the collected probabilities at 0.5 and report per-class metrics.
# Uses the `y_pred` / `y_true` lists left in the namespace by an earlier cell.
y_pred_class = (np.concatenate(y_pred)>0.5).astype(float)
print(classification_report(np.concatenate(y_true), y_pred_class))

In [None]:
class CNNHalf(nn.Module):
    """Reduced CNN: two convolutions with a linear + sigmoid head.

    ``forward`` expects ``x`` laid out as (batch, frames, 40). ``conv1``
    collapses the 40-row feature axis with a (40, 1) kernel, a stride-2
    max-pool halves the frame axis, and ``conv3`` (the layer keeps the name it
    has in the full model) convolves along frames with a width-18 kernel.

    The final max-pool width (511) is hard-coded for 1024-frame inputs
    (1024 -> 512 -> 511 through the pool/conv stack).

    Args:
        num_classes: number of output units of ``fc1`` (default 1).
    """

    def __init__(self, num_classes=1):
        super().__init__()

        self.num_classes = num_classes

        # Channel widths of the two convolutions.
        conv1_C = 256
        conv3_C = 512
        conv_W = 18
        pad = int(np.floor(conv_W / 2) - 1)
        stride_W = 2
        pool_W = 3
        self.conv1 = nn.Conv2d(1, conv1_C, kernel_size=(40,1), stride=(1,1), padding=(0,0))
        self.bn1 = nn.BatchNorm2d(conv1_C)
        self.conv3 = nn.Conv2d(conv1_C, conv3_C, kernel_size=(1,conv_W), stride=(1,1), padding=(0,pad))
        self.bn3 = nn.BatchNorm2d(conv3_C)

        # Down-sampling max-pool applied after conv1.
        pad = int(np.floor(pool_W / 2))
        self.pool = nn.MaxPool2d(kernel_size=(1,pool_W), stride=(1,stride_W), padding=(0,pad))

        # Output head: one linear layer on the globally pooled conv3 channels.
        # map_W is hard coded for 1024-frame inputs (see class docstring).
        map_W = 511
        embedding_dim = conv3_C
        self.fc1 = nn.Linear(embedding_dim, num_classes)
        self.fc2 = nn.Sigmoid()

        # Global max pooling over the remaining frame axis.
        self.poolMax = nn.MaxPool2d(kernel_size=(1,map_W), stride=(1,map_W), padding=(0,0))

    def forward(self, x):
        """Return sigmoid scores of shape (batch, num_classes) for ``x``.

        The method depends only on its argument; the stray read of the
        notebook-global ``batch`` has been removed.
        """
        x = x.transpose(1, 2)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.poolMax(x)
        x = x.squeeze(2)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        

In [None]:
# Instantiate the reduced CNN and display its module structure.
cnn_half = CNNHalf()
cnn_half

In [None]:
# Smoke test: forward one training batch through a fresh CNNHalf and check the output shape.
cnn_half = CNNHalf()
batch = next(iter(tr_ds))
x = cnn_half(batch['feats'])
print(x.shape)

### Small model trained without knowledge distillation (KD)

In [None]:
device = 'cuda'  # NOTE(review): not referenced below; the loop calls .cuda() directly

# Baseline run for the reduced model (CNNHalf): a single (weight decay,
# learning rate, seed) configuration, trained for 5 epochs with SGD + momentum.
# Same loop structure as the full-CNN sweep; loop variables stay in the
# notebook namespace for later cells.
for wd in [1e-7]:
    for lr in [0.01]:
        for seed in [10001]:
            print("++++++ Exp seed %d wd %s lr %f ++++++++" % (seed, str(wd), lr))
            # Seed torch, random and numpy before the model is built.
            torch.manual_seed(seed)
            random.seed(seed)
            np.random.seed(seed)
            model = CNNHalf()
            model.cuda()
            optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)

            # Best dev AUC seen so far for this configuration.
            auc_best = 0.
            for epoch in range(5):
                # ---- training pass ----
                model.train()
                epoch_loss = 0.
                y_pred = []
                y_true = []
                for batch in tqdm.tqdm(tr_ds, total=len(tr_ds)):
                    x = batch['feats']
                    # use lens to do average pooling properly
                    x_lens = batch['feats_len']
                    tgts = batch['targets']
                    x = x.cuda()
                    tgts = tgts.cuda()
                    # Model output is already a sigmoid probability, hence plain BCE.
                    x = model(x).squeeze(1)
                    loss = F.binary_cross_entropy(x, tgts.float(), reduction='mean')
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    epoch_loss += float(loss)
                    y_pred.append(x.detach().cpu().numpy())
                    y_true.append(tgts.cpu().numpy())

                # Training AUC over the predictions collected during the epoch.
                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                print('epoch loss: %f' % (epoch_loss/len(tr_ds)))
                print('auc: %f' % auc)

                # ---- dev pass (no gradient, eval mode) ----
                epoch_loss = 0.
                y_pred = []
                y_true = []
                model.eval()
                with torch.no_grad():
                    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
                        x = batch['feats']
                        # use lens to do average pooling properly
                        x_lens = batch['feats_len']
                        tgts = batch['targets']
                        x = x.cuda()
                        tgts = tgts.cuda()
                        x = model(x).squeeze(1)
                        loss = F.binary_cross_entropy(x, tgts.float(), reduction='mean')
                        epoch_loss += float(loss)
                        y_pred.append(x.detach().cpu().numpy())
                        y_true.append(tgts.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                # Checkpoint whenever dev AUC improves; the file name encodes
                # the AUC and the hyper-parameters and is written to the CWD.
                if auc > auc_best:
                    print("saving ckpt at auc_%f_seed_%d_wd_%s_lr_%f.pt" % (auc, seed, str(wd), lr))
                    torch.save(model.state_dict(), "auc_%f_seed_%d_wd_%s_lr_%f.pt" % (auc, seed, str(wd), lr))
                    auc_best = auc

                print('epoch dev loss: %f' % (epoch_loss/len(dev_ds)))
                print('auc dev: %f' % auc)
    

In [None]:
# Re-evaluate a saved CNNHalf checkpoint on the dev loader and print its AUC.
# NOTE(review): the checkpoint name says seed_3333, but the CNNHalf training
# cell in this notebook only runs seed 10001 — confirm which run produced it.
y_pred = []
y_true = []
model = CNNHalf()
model.load_state_dict(torch.load('auc_0.874646_seed_3333_wd_1e-07_lr_0.010000.pt'))
model.eval()
model.cuda()
with torch.no_grad():
    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
        x = batch['feats']
        # use lens to do average pooling properly
        x_lens = batch['feats_len']
        tgts = batch['targets']
        x = x.cuda()
        tgts = tgts.cuda()
        x = model(x).squeeze(1)
        y_pred.append(x.detach().cpu().numpy())
        y_true.append(tgts.cpu().numpy())

auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))
print(auc)

In [None]:
# Note: in binary classification, recall of the positive class is also known as
# "sensitivity"; recall of the negative class is "specificity".
from sklearn.metrics import classification_report
# Threshold the collected probabilities at 0.5 before computing per-class metrics.
y_pred_class = (np.concatenate(y_pred)>0.5).astype(float)
print(classification_report(np.concatenate(y_true), y_pred_class))

In [4]:
import torchvision

class Identity(nn.Module):
    """Parameter-free pass-through module: ``forward`` returns its input as is.

    Used in this notebook to neutralise sub-modules (pooling / classifier) of
    a torchvision model by replacing them.
    """

    def forward(self, x):
        # Deliberate no-op.
        return x

class EffNetOri(nn.Module):
    """torchvision EfficientNet-b{model_id} adapted to 1-channel feature maps.

    The network's first convolution is replaced by a single-input-channel
    convolution whose weights are the original weights summed over the input
    channel axis. The built-in average pool and classifier are replaced by
    ``Identity`` so that ``self.model.features`` can be used directly, and a
    ``MeanPooling`` head maps the feature map to ``label_dim`` outputs.

    Args:
        label_dim: number of outputs of the pooling head.
        pretrain: forwarded to the torchvision constructor as ``pretrained``.
        model_id: EfficientNet variant, an integer in 0..7.
        audioset_pretrain: accepted for interface compatibility; not used.

    Raises:
        ValueError: if ``model_id`` is not in 0..7 (previously this surfaced
            later as an AttributeError on ``self.model``).

    Note:
        For ``model_id == 0`` the weights are always loaded from the local
        file ``efficientnet_b0_rwightman-3dd342df.pth``, regardless of
        ``pretrain``.
    """

    def __init__(self, label_dim=1, pretrain=False, model_id=0, audioset_pretrain=False):
        super().__init__()
        b = int(model_id)
        if not 0 <= b <= 7:
            raise ValueError('model_id must be an integer in 0..7, got {!r}'.format(model_id))
        print('now train a effnet-b{:d} model'.format(b))
        builder = getattr(torchvision.models, 'efficientnet_b{:d}'.format(b))
        self.model = builder(pretrained=pretrain)
        if b == 0:
            self.model.load_state_dict(torch.load('efficientnet_b0_rwightman-3dd342df.pth'))

        # Swap the stem convolution for a 1-channel one. Its width is taken
        # from the layer being replaced (32 for b0) so the module metadata
        # stays correct for the wider variants as well.
        stem_conv = self.model.features[0][0]
        new_proj = torch.nn.Conv2d(1, stem_conv.out_channels, kernel_size=(3, 3), stride=(2, 2),
                                   padding=(1, 1), bias=False)
        print('conv1 get from pretrained model.')
        new_proj.weight = torch.nn.Parameter(torch.sum(stem_conv.weight, dim=1).unsqueeze(1))
        new_proj.bias = stem_conv.bias
        self.model.features[0][0] = new_proj

        self.model.avgpool = Identity()
        self.model.classifier = Identity()
        self.feat_dim, self.freq_dim = self.get_dim()
        print(self.feat_dim)
        self.attention = MeanPooling(self.feat_dim, label_dim)

    def get_dim(self):
        """Probe the feature extractor and return (channels, dim-2 size) of its output.

        A dummy (10, 1000, 40) input is pushed through ``self.model.features``.
        The probe runs in eval mode under ``torch.no_grad()`` so that it does
        not update BatchNorm running statistics with the all-zero batch or
        build an autograd graph; the previous train/eval mode is restored.
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(10, 1000, 40)
                x = x.unsqueeze(1)
                x = x.transpose(2, 3)
                x = self.model.features(x)
        finally:
            self.model.train(was_training)
        return int(x.shape[1]), int(x.shape[2])

    def forward(self, x):
        """Return sigmoid scores for ``x`` of shape (batch, frames, feature bins)."""
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)
        x = self.model.features(x)
        out = torch.sigmoid(self.attention(x))
        return out

class MeanPooling(nn.Module):
    """Average a 4-D feature map over its last two axes, then project linearly.

    Args:
        n_in: size of axis 1 of the input (input width of the linear layer).
        n_out: number of outputs of the linear layer.

    A ``LayerNorm`` is registered as ``layernorm`` but is not applied in
    ``forward``; it still contributes parameters to the state dict.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.layernorm = nn.LayerNorm(n_in)
        self.linear = nn.Linear(n_in, n_out)
        print('use mean pooling')

    def forward(self, x):
        """Map ``x`` of shape (batch, n_in, d2, d3) to (batch, n_out)."""
        pooled = x.mean(dim=(2, 3))
        return self.linear(pooled)

In [5]:
# Smoke test: build the default EfficientNet wrapper and check the output shape on one batch.
eff_net = EffNetOri()
batch = next(iter(tr_ds))
eff_net(batch['feats']).shape

now train a effnet-b0 model
conv1 get from pretrained model.
1280
use mean pooling


torch.Size([32, 1])

In [None]:
# Rebuild the train and dev loaders (same settings as the earlier loader cell)
# and add a test loader. Each `*_ds` name is first bound to the FHSData
# dataset and then rebound to the DataLoader wrapping it.
# NOTE(review): paths are absolute and machine-specific.
tr_wav_scp = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/train_fra/feats.scp'
tr_utt2label = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/train_fra/utt2label'
cmvn_file = '/data/sls/scratch/sameerk/fhs_prepared/kaldi_data/fbanks/cmvn.ark'
tr_ds = FHSData(tr_wav_scp, tr_utt2label, cmvn_file)
tr_ds = torch.utils.data.DataLoader(tr_ds, batch_size=32, collate_fn=_collate_fn, drop_last=False, 
                                    num_workers=2, shuffle=True)

dev_wav_scp = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/dev_fra/feats.scp'
dev_utt2label = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/dev_fra/utt2label'
cmvn_file = '/data/sls/scratch/sameerk/fhs_prepared/kaldi_data/fbanks/cmvn.ark'
dev_ds = FHSData(dev_wav_scp, dev_utt2label, cmvn_file)
dev_ds = torch.utils.data.DataLoader(dev_ds, batch_size=32, collate_fn=_collate_fn, drop_last=False, 
                                    num_workers=2, shuffle=False)


# Test split: larger batch size (128), no shuffling.
test_wav_scp = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/test_fra/feats.scp'
test_utt2label = '/data/sls/u/sameerk/code/kaldi/egs/librispeech/s5/data/test_fra/utt2label'
cmvn_file = '/data/sls/scratch/sameerk/fhs_prepared/kaldi_data/fbanks/cmvn.ark'
test_ds = FHSData(test_wav_scp, test_utt2label, cmvn_file)
test_ds = torch.utils.data.DataLoader(test_ds, batch_size=128, collate_fn=_collate_fn, drop_last=False, 
                                      num_workers=2, shuffle=False)


In [None]:
device = 'cuda'  # NOTE(review): not referenced below; the loop calls .cuda() directly

# Hyper-parameter sweep for the EfficientNet wrapper: every (weight decay,
# learning rate, seed) combination trains a fresh EffNetOri for 5 epochs with
# SGD + momentum. Same loop structure as the CNN sweeps; loop variables stay
# in the notebook namespace for later cells.
for wd in [1e-7, 1e-5, 1e-3]:
    for lr in [0.01, 0.005, 0.08, 0.03, 0.001]:
        for seed in [3333, 1111, 6825, 2222, 1234, 5555, 10001, 345667]:
            print("++++++ Exp seed %d wd %s lr %f ++++++++" % (seed, str(wd), lr))
            # Seed torch, random and numpy before the model is built.
            torch.manual_seed(seed)
            random.seed(seed)
            np.random.seed(seed)
            model = EffNetOri()
            model.cuda()
            optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)

            # Best dev AUC seen so far for this configuration.
            auc_best = 0.
            for epoch in range(5):
                # ---- training pass ----
                model.train()
                epoch_loss = 0.
                y_pred = []
                y_true = []
                for batch in tqdm.tqdm(tr_ds, total=len(tr_ds)):
                    x = batch['feats']
                    # use lens to do average pooling properly
                    x_lens = batch['feats_len']
                    tgts = batch['targets']
                    x = x.cuda()
                    tgts = tgts.cuda()
                    # Model output is already a sigmoid probability, hence plain BCE.
                    x = model(x).squeeze(1)
                    loss = F.binary_cross_entropy(x, tgts.float(), reduction='mean')
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    epoch_loss += float(loss)
                    y_pred.append(x.detach().cpu().numpy())
                    y_true.append(tgts.cpu().numpy())

                # Training AUC over the predictions collected during the epoch.
                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                print('epoch loss: %f' % (epoch_loss/len(tr_ds)))
                print('auc: %f' % auc)

                # ---- dev pass (no gradient, eval mode) ----
                epoch_loss = 0.
                y_pred = []
                y_true = []
                model.eval()
                with torch.no_grad():
                    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
                        x = batch['feats']
                        # use lens to do average pooling properly
                        x_lens = batch['feats_len']
                        tgts = batch['targets']
                        x = x.cuda()
                        tgts = tgts.cuda()
                        x = model(x).squeeze(1)
                        loss = F.binary_cross_entropy(x, tgts.float(), reduction='mean')
                        epoch_loss += float(loss)
                        y_pred.append(x.detach().cpu().numpy())
                        y_true.append(tgts.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                # Checkpoint whenever dev AUC improves; the file name encodes
                # the AUC and the hyper-parameters and is written to the CWD.
                if auc > auc_best:
                    print("saving ckpt at auc_%f_seed_%d_wd_%s_lr_%f.pt" % (auc, seed, str(wd), lr))
                    torch.save(model.state_dict(), "auc_%f_seed_%d_wd_%s_lr_%f.pt" % (auc, seed, str(wd), lr))
                    auc_best = auc

                print('epoch dev loss: %f' % (epoch_loss/len(dev_ds)))
                print('auc dev: %f' % auc)
    

In [None]:
from sklearn.metrics import classification_report
import os

# Re-evaluate a saved EffNetOri checkpoint on the dev loader: print the module
# structure, the dev AUC, and a classification report at a 0.5 threshold.
y_pred = []
y_true = []
model = EffNetOri()
model.load_state_dict(torch.load('auc_0.911797_seed_10001_wd_1e-07_lr_0.030000.pt'))
# Was `print(m)`: `m` is never defined, so the cell raised NameError.
print(model)
model.eval()
model.cuda()
with torch.no_grad():
    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
        x = batch['feats']
        # use lens to do average pooling properly
        x_lens = batch['feats_len']
        tgts = batch['targets']
        x = x.cuda()
        tgts = tgts.cuda()
        x = model(x).squeeze(1)
        y_pred.append(x.detach().cpu().numpy())
        y_true.append(tgts.cpu().numpy())

auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))
print(auc)
y_pred_class = (np.concatenate(y_pred)>0.5).astype(float)
print(classification_report(np.concatenate(y_true), y_pred_class))

In [None]:
# Note: in binary classification, recall of the positive class is also known as
# "sensitivity"; recall of the negative class is "specificity".
from sklearn.metrics import classification_report
# Threshold the collected probabilities at 0.5 before computing per-class metrics.
y_pred_class = (np.concatenate(y_pred)>0.5).astype(float)
print(classification_report(np.concatenate(y_true), y_pred_class))

### Knowledge distillation (KD) into the small model

#### Load the big models used for KD

In [None]:
# Build the two large models and restore their trained weights from saved checkpoints.
# NOTE(review): checkpoints are read from the relative directories
# `models_effnet/` and `models/`, so this depends on the working directory.
eff_net = EffNetOri()
big_cnn = CNN()
eff_net.load_state_dict(torch.load('models_effnet/auc_0.911797_seed_10001_wd_1e-07_lr_0.030000.pt'))
big_cnn.load_state_dict(torch.load('models/auc_0.852169_seed_2222_wd_1e-07_lr_0.010000.pt'))

#### perform KD

In [7]:
class CNNSmallV1(nn.Module):
    """Small two-convolution CNN with a sigmoid output (256 -> 1024 channels).

    Layer stack, in forward order:
      1. ``conv1``: (40, 1) convolution that collapses the 40-row feature axis.
      2. ``pool``: width-3, stride-2 max pooling along the remaining axis.
      3. ``conv3``: (1, 18) convolution along that axis.
      4. ``poolMax``: max pooling with a fixed window of 511 positions.
      5. ``fc1`` + ``fc2``: linear layer followed by a sigmoid.

    ``poolMax`` uses a hard-coded window, so the ``conv3`` output must be 511
    positions wide for the pooling to be a true global max; with the layers
    above that corresponds to 1023 or 1024 input frames.

    Args:
        num_classes: number of sigmoid outputs produced per example.
    """

    def __init__(self, num_classes=1):
        super(CNNSmallV1, self).__init__()

        self.num_classes = num_classes

        # Channel widths of the two convolutions.
        conv1_C = 256
        conv3_C = 1024
        # Kernel / pooling geometry along the time axis.
        conv_W = 18
        stride_W = 2
        pool_W = 3
        # Width of the conv3 feature map that poolMax reduces; hard coded,
        # figure it out manually if the input length changes.
        map_W = 511

        # Defining convolutional layers
        conv_pad = int(np.floor(conv_W / 2) - 1)
        self.conv1 = nn.Conv2d(1, conv1_C, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
        self.bn1 = nn.BatchNorm2d(conv1_C)
        self.conv3 = nn.Conv2d(conv1_C, conv3_C, kernel_size=(1, conv_W), stride=(1, 1), padding=(0, conv_pad))
        self.bn3 = nn.BatchNorm2d(conv3_C)

        # Defining pooling
        pool_pad = int(np.floor(pool_W / 2))
        self.pool = nn.MaxPool2d(kernel_size=(1, pool_W), stride=(1, stride_W), padding=(0, pool_pad))

        # Defining output layer: one input feature per conv3 channel.
        self.fc1 = nn.Linear(conv3_C, num_classes)
        self.fc2 = nn.Sigmoid()

        # Defining global max pooling
        self.poolMax = nn.MaxPool2d(kernel_size=(1, map_W), stride=(1, map_W), padding=(0, 0))

    def forward(self, x):
        """Return sigmoid outputs of shape (batch, num_classes).

        Args:
            x: input tensor whose last two axes are swapped before the
               convolutions; after the swap the 40-wide axis must come first,
               i.e. a 3-D input is laid out as (batch, frames, 40).
        """
        # The forward pass depends only on its argument; it must not read
        # notebook globals such as the current training batch.
        x = x.transpose(1, 2)
        if x.dim() == 3:
            # add the single input channel expected by Conv2d
            x = x.unsqueeze(1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.poolMax(x)
        # drop the collapsed feature axis, then flatten to (batch, channels)
        x = x.squeeze(2)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        

In [None]:
# Knowledge-distillation sweep for CNNSmallV1: the two frozen teachers
# (eff_net, big_cnn) supply soft targets, and a fresh student is trained for
# every (weight decay, learning rate, seed) combination.
device = 'cuda'
for teacher in (big_cnn, eff_net):
    teacher.eval()
    teacher.to(device)

for wd in [1e-7, 1e-5, 1e-3]:
    for lr in [0.01, 0.005, 0.08, 0.03, 0.001]:
        for seed in [3333, 1111, 6825, 2222, 1234, 5555, 10001, 345667]:
            print("++++++ Exp seed %d wd %s lr %f ++++++++" % (seed, str(wd), lr))
            # Seed all three RNGs before the student is built.
            for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
                seed_fn(seed)
            model = CNNSmallV1()
            model.to(device)
            optim = torch.optim.SGD(model.parameters(), weight_decay=wd, momentum=0.9, lr=lr)

            auc_best = 0.
            for epoch in range(5):
                # ---- training pass ----
                model.train()
                epoch_loss = 0.
                y_pred, y_true = [], []
                for batch in tqdm.tqdm(tr_ds, total=len(tr_ds)):
                    # lengths are fetched but not used yet
                    # (meant for doing average pooling properly)
                    x_lens = batch['feats_len']
                    feats = batch['feats'].to(device)
                    labels = batch['targets'].to(device)
                    # Teacher outputs only serve as targets: no graph needed.
                    with torch.no_grad():
                        teacher_eff = eff_net(feats).squeeze(1)
                        teacher_cnn = big_cnn(feats).squeeze(1)
                    probs = model(feats).squeeze(1)
                    # hard-label loss plus one soft-target loss per teacher
                    hard_loss = F.binary_cross_entropy(probs, labels.float(), reduction='mean')
                    soft_loss_eff = F.binary_cross_entropy(probs, teacher_eff, reduction='mean')
                    soft_loss_cnn = F.binary_cross_entropy(probs, teacher_cnn, reduction='mean')
                    loss = hard_loss + soft_loss_eff + soft_loss_cnn
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    epoch_loss += loss.item()
                    y_pred.append(probs.detach().cpu().numpy())
                    y_true.append(labels.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                print('epoch loss: %f' % (epoch_loss/len(tr_ds)))
                print('auc: %f' % auc)

                # ---- dev pass (hard-label loss only) ----
                epoch_loss = 0.
                y_pred, y_true = [], []
                model.eval()
                with torch.no_grad():
                    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
                        x_lens = batch['feats_len']
                        feats = batch['feats'].to(device)
                        labels = batch['targets'].to(device)
                        probs = model(feats).squeeze(1)
                        loss = F.binary_cross_entropy(probs, labels.float(), reduction='mean')
                        epoch_loss += loss.item()
                        y_pred.append(probs.detach().cpu().numpy())
                        y_true.append(labels.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                # Checkpoint whenever this run's dev AUC improves.
                if auc > auc_best:
                    ckpt_fields = (auc, seed, str(wd), lr)
                    print("saving ckpt at auc_%f_seed_%d_wd_%s_lr_%f.pt" % ckpt_fields)
                    torch.save(model.state_dict(), "auc_%f_seed_%d_wd_%s_lr_%f.pt" % ckpt_fields)
                    auc_best = auc

                print('epoch dev loss: %f' % (epoch_loss/len(dev_ds)))
                print('auc dev: %f' % auc)
    

In [14]:
class CNNSmallV2(nn.Module):
    """Small two-convolution CNN with a sigmoid output (128 -> 256 channels).

    Layer stack, in forward order:
      1. ``conv1``: (40, 1) convolution that collapses the 40-row feature axis.
      2. ``pool``: width-3, stride-2 max pooling along the remaining axis.
      3. ``conv3``: (1, 18) convolution along that axis.
      4. ``poolMax``: max pooling with a fixed window of 511 positions.
      5. ``fc1`` + ``fc2``: linear layer followed by a sigmoid.

    ``poolMax`` uses a hard-coded window, so the ``conv3`` output must be 511
    positions wide for the pooling to be a true global max; with the layers
    above that corresponds to 1023 or 1024 input frames.

    Args:
        num_classes: number of sigmoid outputs produced per example.
    """

    def __init__(self, num_classes=1):
        super(CNNSmallV2, self).__init__()

        self.num_classes = num_classes

        # Channel widths of the two convolutions.
        conv1_C = 128
        conv3_C = 256
        # Kernel / pooling geometry along the time axis.
        conv_W = 18
        stride_W = 2
        pool_W = 3
        # Width of the conv3 feature map that poolMax reduces; hard coded,
        # figure it out manually if the input length changes.
        map_W = 511

        # Defining convolutional layers
        conv_pad = int(np.floor(conv_W / 2) - 1)
        self.conv1 = nn.Conv2d(1, conv1_C, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
        self.bn1 = nn.BatchNorm2d(conv1_C)
        self.conv3 = nn.Conv2d(conv1_C, conv3_C, kernel_size=(1, conv_W), stride=(1, 1), padding=(0, conv_pad))
        self.bn3 = nn.BatchNorm2d(conv3_C)

        # Defining pooling
        pool_pad = int(np.floor(pool_W / 2))
        self.pool = nn.MaxPool2d(kernel_size=(1, pool_W), stride=(1, stride_W), padding=(0, pool_pad))

        # Defining output layer: one input feature per conv3 channel.
        self.fc1 = nn.Linear(conv3_C, num_classes)
        self.fc2 = nn.Sigmoid()

        # Defining global max pooling
        self.poolMax = nn.MaxPool2d(kernel_size=(1, map_W), stride=(1, map_W), padding=(0, 0))

    def forward(self, x):
        """Return sigmoid outputs of shape (batch, num_classes).

        Args:
            x: input tensor whose last two axes are swapped before the
               convolutions; after the swap the 40-wide axis must come first,
               i.e. a 3-D input is laid out as (batch, frames, 40).
        """
        # The forward pass depends only on its argument; it must not read
        # notebook globals such as the current training batch.
        x = x.transpose(1, 2)
        if x.dim() == 3:
            # add the single input channel expected by Conv2d
            x = x.unsqueeze(1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.poolMax(x)
        # drop the collapsed feature axis, then flatten to (batch, channels)
        x = x.squeeze(2)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        

In [None]:
# Knowledge-distillation sweep for CNNSmallV2: the two frozen teachers
# (eff_net, big_cnn) supply soft targets, and a fresh student is trained for
# every (weight decay, learning rate, seed) combination.
device = 'cuda'
for teacher in (big_cnn, eff_net):
    teacher.eval()
    teacher.to(device)

for wd in [1e-7, 1e-5, 1e-3]:
    for lr in [0.01, 0.005, 0.08, 0.03, 0.001]:
        for seed in [3333, 1111, 6825, 2222, 1234, 5555, 10001, 345667]:
            print("++++++ Exp seed %d wd %s lr %f ++++++++" % (seed, str(wd), lr))
            # Seed all three RNGs before the student is built.
            for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
                seed_fn(seed)
            model = CNNSmallV2()
            model.to(device)
            optim = torch.optim.SGD(model.parameters(), weight_decay=wd, momentum=0.9, lr=lr)

            auc_best = 0.
            for epoch in range(5):
                # ---- training pass ----
                model.train()
                epoch_loss = 0.
                y_pred, y_true = [], []
                for batch in tqdm.tqdm(tr_ds, total=len(tr_ds)):
                    # lengths are fetched but not used yet
                    # (meant for doing average pooling properly)
                    x_lens = batch['feats_len']
                    feats = batch['feats'].to(device)
                    labels = batch['targets'].to(device)
                    # Teacher outputs only serve as targets: no graph needed.
                    with torch.no_grad():
                        teacher_eff = eff_net(feats).squeeze(1)
                        teacher_cnn = big_cnn(feats).squeeze(1)
                    probs = model(feats).squeeze(1)
                    # hard-label loss plus one soft-target loss per teacher
                    hard_loss = F.binary_cross_entropy(probs, labels.float(), reduction='mean')
                    soft_loss_eff = F.binary_cross_entropy(probs, teacher_eff, reduction='mean')
                    soft_loss_cnn = F.binary_cross_entropy(probs, teacher_cnn, reduction='mean')
                    loss = hard_loss + soft_loss_eff + soft_loss_cnn
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    epoch_loss += loss.item()
                    y_pred.append(probs.detach().cpu().numpy())
                    y_true.append(labels.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                print('epoch loss: %f' % (epoch_loss/len(tr_ds)))
                print('auc: %f' % auc)

                # ---- dev pass (hard-label loss only) ----
                epoch_loss = 0.
                y_pred, y_true = [], []
                model.eval()
                with torch.no_grad():
                    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
                        x_lens = batch['feats_len']
                        feats = batch['feats'].to(device)
                        labels = batch['targets'].to(device)
                        probs = model(feats).squeeze(1)
                        loss = F.binary_cross_entropy(probs, labels.float(), reduction='mean')
                        epoch_loss += loss.item()
                        y_pred.append(probs.detach().cpu().numpy())
                        y_true.append(labels.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                # Checkpoint whenever this run's dev AUC improves.
                if auc > auc_best:
                    ckpt_fields = (auc, seed, str(wd), lr)
                    print("saving ckpt at auc_%f_seed_%d_wd_%s_lr_%f.pt" % ckpt_fields)
                    torch.save(model.state_dict(), "auc_%f_seed_%d_wd_%s_lr_%f.pt" % ckpt_fields)
                    auc_best = auc

                print('epoch dev loss: %f' % (epoch_loss/len(dev_ds)))
                print('auc dev: %f' % auc)
    

In [16]:
class CNNSmallV3(nn.Module):
    """Small two-convolution CNN with a sigmoid output (64 -> 128 channels).

    Layer stack, in forward order:
      1. ``conv1``: (40, 1) convolution that collapses the 40-row feature axis.
      2. ``pool``: width-3, stride-2 max pooling along the remaining axis.
      3. ``conv3``: (1, 18) convolution along that axis.
      4. ``poolMax``: max pooling with a fixed window of 511 positions.
      5. ``fc1`` + ``fc2``: linear layer followed by a sigmoid.

    ``poolMax`` uses a hard-coded window, so the ``conv3`` output must be 511
    positions wide for the pooling to be a true global max; with the layers
    above that corresponds to 1023 or 1024 input frames.

    Args:
        num_classes: number of sigmoid outputs produced per example.
    """

    def __init__(self, num_classes=1):
        super(CNNSmallV3, self).__init__()

        self.num_classes = num_classes

        # Channel widths of the two convolutions.
        conv1_C = 64
        conv3_C = 128
        # Kernel / pooling geometry along the time axis.
        conv_W = 18
        stride_W = 2
        pool_W = 3
        # Width of the conv3 feature map that poolMax reduces; hard coded,
        # figure it out manually if the input length changes.
        map_W = 511

        # Defining convolutional layers
        conv_pad = int(np.floor(conv_W / 2) - 1)
        self.conv1 = nn.Conv2d(1, conv1_C, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
        self.bn1 = nn.BatchNorm2d(conv1_C)
        self.conv3 = nn.Conv2d(conv1_C, conv3_C, kernel_size=(1, conv_W), stride=(1, 1), padding=(0, conv_pad))
        self.bn3 = nn.BatchNorm2d(conv3_C)

        # Defining pooling
        pool_pad = int(np.floor(pool_W / 2))
        self.pool = nn.MaxPool2d(kernel_size=(1, pool_W), stride=(1, stride_W), padding=(0, pool_pad))

        # Defining output layer: one input feature per conv3 channel.
        self.fc1 = nn.Linear(conv3_C, num_classes)
        self.fc2 = nn.Sigmoid()

        # Defining global max pooling
        self.poolMax = nn.MaxPool2d(kernel_size=(1, map_W), stride=(1, map_W), padding=(0, 0))

    def forward(self, x):
        """Return sigmoid outputs of shape (batch, num_classes).

        Args:
            x: input tensor whose last two axes are swapped before the
               convolutions; after the swap the 40-wide axis must come first,
               i.e. a 3-D input is laid out as (batch, frames, 40).
        """
        # The forward pass depends only on its argument; it must not read
        # notebook globals such as the current training batch.
        x = x.transpose(1, 2)
        if x.dim() == 3:
            # add the single input channel expected by Conv2d
            x = x.unsqueeze(1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.poolMax(x)
        # drop the collapsed feature axis, then flatten to (batch, channels)
        x = x.squeeze(2)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        

In [None]:
# Knowledge-distillation sweep for CNNSmallV3: the two frozen teachers
# (eff_net, big_cnn) supply soft targets, and a fresh student is trained for
# every (weight decay, learning rate, seed) combination.
device = 'cuda'
for teacher in (big_cnn, eff_net):
    teacher.eval()
    teacher.to(device)

for wd in [1e-7, 1e-5, 1e-3]:
    for lr in [0.01, 0.005, 0.08, 0.03, 0.001]:
        for seed in [3333, 1111, 6825, 2222, 1234, 5555, 10001, 345667]:
            print("++++++ Exp seed %d wd %s lr %f ++++++++" % (seed, str(wd), lr))
            # Seed all three RNGs before the student is built.
            for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
                seed_fn(seed)
            model = CNNSmallV3()
            model.to(device)
            optim = torch.optim.SGD(model.parameters(), weight_decay=wd, momentum=0.9, lr=lr)

            auc_best = 0.
            for epoch in range(5):
                # ---- training pass ----
                model.train()
                epoch_loss = 0.
                y_pred, y_true = [], []
                for batch in tqdm.tqdm(tr_ds, total=len(tr_ds)):
                    # lengths are fetched but not used yet
                    # (meant for doing average pooling properly)
                    x_lens = batch['feats_len']
                    feats = batch['feats'].to(device)
                    labels = batch['targets'].to(device)
                    # Teacher outputs only serve as targets: no graph needed.
                    with torch.no_grad():
                        teacher_eff = eff_net(feats).squeeze(1)
                        teacher_cnn = big_cnn(feats).squeeze(1)
                    probs = model(feats).squeeze(1)
                    # hard-label loss plus one soft-target loss per teacher
                    hard_loss = F.binary_cross_entropy(probs, labels.float(), reduction='mean')
                    soft_loss_eff = F.binary_cross_entropy(probs, teacher_eff, reduction='mean')
                    soft_loss_cnn = F.binary_cross_entropy(probs, teacher_cnn, reduction='mean')
                    loss = hard_loss + soft_loss_eff + soft_loss_cnn
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    epoch_loss += loss.item()
                    y_pred.append(probs.detach().cpu().numpy())
                    y_true.append(labels.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                print('epoch loss: %f' % (epoch_loss/len(tr_ds)))
                print('auc: %f' % auc)

                # ---- dev pass (hard-label loss only) ----
                epoch_loss = 0.
                y_pred, y_true = [], []
                model.eval()
                with torch.no_grad():
                    for batch in tqdm.tqdm(dev_ds, total=len(dev_ds)):
                        x_lens = batch['feats_len']
                        feats = batch['feats'].to(device)
                        labels = batch['targets'].to(device)
                        probs = model(feats).squeeze(1)
                        loss = F.binary_cross_entropy(probs, labels.float(), reduction='mean')
                        epoch_loss += loss.item()
                        y_pred.append(probs.detach().cpu().numpy())
                        y_true.append(labels.cpu().numpy())

                auc = roc_auc_score(np.concatenate(y_true), np.concatenate(y_pred))

                # Checkpoint whenever this run's dev AUC improves.
                if auc > auc_best:
                    ckpt_fields = (auc, seed, str(wd), lr)
                    print("saving ckpt at auc_%f_seed_%d_wd_%s_lr_%f.pt" % ckpt_fields)
                    torch.save(model.state_dict(), "auc_%f_seed_%d_wd_%s_lr_%f.pt" % ckpt_fields)
                    auc_best = auc

                print('epoch dev loss: %f' % (epoch_loss/len(dev_ds)))
                print('auc dev: %f' % auc)
    