In [1]:
from transformers import DataCollatorWithPadding,AutoModelForSequenceClassification, Trainer, TrainingArguments,AutoTokenizer,AutoModel,AutoConfig
from transformers import default_data_collator
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
import torch
import torch.nn as nn
from datasets import Dataset, Audio, Value, Features,load_dataset,ClassLabel
from transformers import Wav2Vec2Processor
from transformers.modeling_outputs import SequenceClassifierOutput
import numpy as np
from transformers import AdamW,get_scheduler
from datasets import load_metric
from tqdm.auto import tqdm
import os
from torch.utils.data import DataLoader
from datasets import load_metric


def featurize(batch):
#     audio_arrays = [batch['audio'][i]['array'] for i in range(len(batch))]
    audio_arrays = [batch['audio'][i]['array'] for i in range(len(batch['id']))]
#     print(len(audio_arrays))
    inputs = feature_extractor(
        audio_arrays, 
        sampling_rate=16_000, 
        max_length=int(16_000 * 10),  # 10s
        truncation=True, 
        padding='max_length',
    )
    return inputs



weighted_sum = False
checkpoint = "facebook/hubert-base-ls960"
# checkpoint = "facebook/wav2vec2-large-lv60"

x = [str(i) for i in range(0,100,1)]
features = Features(
    {
        "id": Value("string"),
        "speaker_id": Value("string"), 
        'path': Value('string'),
        "audio": Audio(sampling_rate=16000),
        "label": ClassLabel(num_classes=100,names=x,names_file=None,id=None)
    }
)


dataset = load_dataset('csv', 
                       data_files={'test':'../data/identification/test_100.csv'},
                       features=features)
dataset = dataset.map(remove_columns=(['path','speaker_id']),num_proc=24)
dataset = dataset.sort("label")
# sampling_rate = dataset.features["audio"].sampling_rate
if 'base' in checkpoint and not weighted_sum:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint,return_attention_mask=False)
else:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint,return_attention_mask=True)


dataset = dataset.map(featurize, remove_columns='audio',batched=True,num_proc=20,batch_size=1)

# dataset = dataset.rename_column('id','label_ids')

if 'base' in checkpoint and not weighted_sum:
    dataset.set_format("torch",columns=["id","input_values", "label"])
else:
    dataset.set_format("torch",columns=["id","input_values", "attention_mask", "label"])

  from .autonotebook import tqdm as notebook_tqdm
Using custom data configuration default-43c2fe400d73e135
Found cached dataset csv (/storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-43c2fe400d73e135/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 331.64it/s]
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-43c2fe400d73e135/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-2b43112609dacf4d.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-43c2fe400d73e135/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-ecfe8e2cea57772c.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-43c2fe400d73e135/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-07049e0558f

### Model

In [2]:
class CustomBaseSID(nn.Module):
    def __init__(self,checkpoint,num_labels,inter_layer_num,attend):
        
        ### attend is a boolean
        super(CustomBaseSID, self).__init__()
#         self.hubert = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid")
        self.model =AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=attend,output_hidden_states=True))
        self.num_labels = num_labels
        self.attend_mask = attend
        ### New layers:
        self.linear1 = nn.Linear(768, 1024)
        self.linear2 = nn.Linear(1024, num_labels)
        
        ### Intermediate Layer Number
        self.layer_num = inter_layer_num

    def forward(self, input_values=None, attention_mask=None,labels=None):
        if(self.attend_mask):
            outputs = self.model(input_values=input_values, attention_mask=attention_mask)
        else:
            outputs = self.model(input_values=input_values, attention_mask=None)
        feature = outputs.hidden_states[self.layer_num-1]
        agg_vec_list = []
        for i in range(len(feature)):
            if(attention_mask==None):
                length = len(feature[i])
            
            else:
                if torch.nonzero(attention_mask[i] < 0, as_tuple=False).size(0) == 0:
                    length = len(feature[i])
                else:
                    length = torch.nonzero(attention_mask[i] < 0, as_tuple=False)[0] + 1
            agg_vec=torch.mean(feature[i][:length], dim=0)
            agg_vec_list.append(agg_vec)
        mean = torch.stack(agg_vec_list)
        # sequence_output has the following shape: (batch_size, sequence_length, 768)
        linear1_output = self.linear1(mean) ## extract the 1st token's embeddings
        logits = self.linear2(linear1_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)


In [3]:
def freeze_layers_transformer(model_ft,keywords,inter_layer,encoder_mode=True):
    ct = 0
    for child in model_ft.children():
        for name,param in child.named_parameters():
            if(encoder_mode):
                if 'encoder' in name and str(ct) in name:
                    ct += 1
                    if ct < inter_layer:
                        param.requires_grad = False
                    
            for word in keywords:         
                if word in name:
                    param.requires_grad = False
                            
    return model_ft

In [4]:
dataset

DatasetDict({
    test: Dataset({
        features: ['id', 'label', 'input_values'],
        num_rows: 632
    })
})

In [5]:
def load_model(checkpoint, num_labels, inter_layer, model_dirpath, device):
    dirpath = model_dirpath + str(inter_layer)+'/'
    PATH = dirpath+str(4)+'.pt'
    if 'base' in checkpoint:
        attend = False
#     model_name = "model_ft_" + pretrained_model_name + str(inter_layer)
    model_ft = CustomBaseSID(checkpoint=checkpoint,num_labels=100,inter_layer_num=inter_layer,attend=attend)
    model_ft.load_state_dict(torch.load(PATH, map_location=device))

    return model_ft

In [6]:
metric_name="accuracy"
metric = load_metric(metric_name)
test_dataloader = DataLoader(dataset["test"], batch_size=1, collate_fn=default_data_collator)

  metric = load_metric(metric_name)


In [7]:
# keys = list(dataset['test'].features.keys())
keys = list(dataset['test']['id'])
test_dict = dict.fromkeys(keys)

In [8]:
list(test_dict.keys())[0]

'id10001-Y8hIVOBuels-00001'

In [9]:
def evaluate(test_dataloader,model_ft):
    test_dict = dict.fromkeys(list(dataset['test']['id']))
    keys = list(test_dict.keys())
    
    for i,batch in enumerate(test_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model_ft(**batch)
        logits = outputs.logits
        sm = torch.nn.Softmax(dim=1)
        probabilities = sm(logits)
        test_dict[keys[i]] = dict.fromkeys(['labels','probabilities','predictions'])
        test_dict[keys[i]]['labels'] = batch['labels'].detach().numpy()
        test_dict[keys[i]]['probabilities'] = probabilities[0].detach().numpy()
        predictions = torch.argmax(logits, dim=-1)
        test_dict[keys[i]]['predictions'] = predictions[0].detach().numpy()
        metric.add_batch(predictions=predictions, references=batch["labels"])
    accuracy = metric.compute()
    return test_dict,accuracy

## hubertbase 12 layer

In [14]:
checkpoint_hubertbase = 'facebook/hubert-base-ls960'
model_dirpath = "/storage/home/hcocice1/vkotra3/6254_Project/code/hubert-base/"
device = "cpu"
model_ft_hubertbase_12 = load_model(checkpoint_hubertbase, 100, 12, model_dirpath, device)

hubertbase_12,hubertbase_12_accuracy = evaluate(test_dataloader,model_ft_hubertbase_12)

## hubertbase 9 layer

In [19]:
checkpoint_hubertbase = 'facebook/hubert-base-ls960'
model_dirpath = "/storage/home/hcocice1/vkotra3/6254_Project/code/hubert-base/"
device = "cpu"
model_ft_hubertbase_9 = load_model(checkpoint_hubertbase, 100, 9, model_dirpath, device)

hubertbase_9,hubertbase_9_accuracy = evaluate(test_dataloader,model_ft_hubertbase_9)

## hubertbase 6 layer

In [15]:
checkpoint_hubertbase = 'facebook/hubert-base-ls960'
model_dirpath = "/storage/home/hcocice1/vkotra3/6254_Project/code/hubert-base/"
device = "cpu"
model_ft_hubertbase_6 = load_model(checkpoint_hubertbase, 100, 6, model_dirpath, device)

hubertbase_6,hubertbase_6_accuracy = evaluate(test_dataloader,model_ft_hubertbase_6)

## hubertbase results

### 12 layer

In [27]:
hubertbase_12

{'id10001-Y8hIVOBuels-00001': {'labels': array([0]),
  'probabilities': array([3.94152164e-01, 4.84396696e-05, 2.61282752e-04, 6.72647788e-04,
         1.38084926e-02, 3.24053777e-04, 7.41351323e-05, 1.01796350e-05,
         4.01328243e-02, 1.19098579e-03, 5.48485725e-04, 1.22963165e-05,
         2.25877942e-04, 1.65788351e-05, 6.58127305e-04, 2.21315750e-05,
         1.69755840e-05, 2.46690906e-04, 5.94912453e-05, 1.30179909e-03,
         7.17953662e-04, 2.49015011e-05, 2.79264845e-04, 1.18639169e-03,
         1.05882655e-05, 3.46929795e-04, 1.56726083e-03, 5.90059935e-05,
         4.04924125e-04, 1.03173603e-04, 5.08172030e-04, 3.59937330e-05,
         4.55926260e-04, 4.30797400e-06, 2.22485582e-03, 9.14659540e-05,
         5.86152717e-04, 1.89333805e-04, 2.28657682e-05, 9.02368083e-06,
         4.56147594e-04, 3.44344560e-04, 4.00234276e-05, 5.61897104e-05,
         3.82362487e-05, 2.51189540e-05, 5.35444997e-05, 2.91154283e-04,
         5.84921418e-06, 1.06375708e-04, 4.15303773e-0

In [16]:
hubertbase_12_accuracy

{'accuracy': 0.938121546961326}

### 9 layer

In [21]:
hubertbase_9

{'id10001-Y8hIVOBuels-00001': {'labels': array([0]),
  'probabilities': array([6.41823590e-01, 5.03017801e-11, 3.99939881e-06, 1.67128237e-04,
         3.23117971e-02, 4.08402266e-05, 9.96526417e-09, 3.46162494e-07,
         3.88069847e-03, 1.09123102e-04, 7.37607252e-07, 2.14458139e-12,
         2.70939017e-06, 4.10376011e-10, 9.40262090e-09, 4.50690929e-10,
         2.62221289e-09, 3.53751730e-08, 8.73050976e-06, 2.76035371e-05,
         5.07640943e-05, 1.05628983e-09, 9.77194140e-05, 1.13263830e-07,
         3.01933166e-11, 1.75469811e-03, 4.56229376e-04, 4.52008297e-09,
         4.45697346e-08, 2.72828174e-06, 1.75812937e-04, 4.20413204e-10,
         1.78296256e-09, 1.02390691e-08, 3.47434059e-02, 1.75148773e-09,
         5.92689204e-04, 2.49733694e-05, 2.19645671e-08, 1.49324073e-08,
         4.52057225e-10, 1.30667900e-06, 3.71234563e-08, 4.13681647e-07,
         1.01976683e-09, 1.90482012e-08, 4.85473919e-08, 1.09542566e-06,
         2.16581530e-10, 8.19829802e-05, 4.55291598e-0

In [20]:
hubertbase_9_accuracy

{'accuracy': 0.9050632911392406}

### 6 layer

In [31]:
hubertbase_6

{'id10001-Y8hIVOBuels-00001': {'labels': array([0]),
  'probabilities': array([3.94152164e-01, 4.84396696e-05, 2.61282752e-04, 6.72647788e-04,
         1.38084926e-02, 3.24053777e-04, 7.41351323e-05, 1.01796350e-05,
         4.01328243e-02, 1.19098579e-03, 5.48485725e-04, 1.22963165e-05,
         2.25877942e-04, 1.65788351e-05, 6.58127305e-04, 2.21315750e-05,
         1.69755840e-05, 2.46690906e-04, 5.94912453e-05, 1.30179909e-03,
         7.17953662e-04, 2.49015011e-05, 2.79264845e-04, 1.18639169e-03,
         1.05882655e-05, 3.46929795e-04, 1.56726083e-03, 5.90059935e-05,
         4.04924125e-04, 1.03173603e-04, 5.08172030e-04, 3.59937330e-05,
         4.55926260e-04, 4.30797400e-06, 2.22485582e-03, 9.14659540e-05,
         5.86152717e-04, 1.89333805e-04, 2.28657682e-05, 9.02368083e-06,
         4.56147594e-04, 3.44344560e-04, 4.00234276e-05, 5.61897104e-05,
         3.82362487e-05, 2.51189540e-05, 5.35444997e-05, 2.91154283e-04,
         5.84921418e-06, 1.06375708e-04, 4.15303773e-0

In [22]:
hubertbase_6_accuracy

{'accuracy': 0.8924050632911392}

## hubertbase ensemble

In [23]:
num_models = 3
ensemble_test_dict = hubertbase_6.copy()
key = 'id10001-Y8hIVOBuels-00001'

for key in hubertbase_6.keys():
    mean_prob = np.asarray([sum(x) for x in zip(hubertbase_6[key]['probabilities'], hubertbase_9[key]['probabilities'], hubertbase_12[key]['probabilities'])])/num_models
    ensemble_test_dict[key]['probabilities'] = mean_prob
    ensemble_test_dict[key]['predictions'] = np.argmax(mean_prob)
    metric.add(prediction=ensemble_test_dict[key]['predictions'], reference = ensemble_test_dict[key]['labels'])

accuracy = metric.compute()

In [24]:
accuracy

{'accuracy': 0.9731012658227848}

## Top 3 Probabilities

In [26]:
def get_top3_accuracy(out_dict):
    test_length = 632
    correct_count=0
    for key in out_dict.keys():
        val=out_dict[key]
        out_dict[key]['top3'] = []
        top_3_out_dict=sorted(zip(val['probabilities'],range(len(val['probabilities']))), reverse = True)[:3]
        out_dict[key]['top3'] = top_3_out_dict
    for key in out_dict.keys():
        top3_labels = [x[1] for x in out_dict[key]['top3']]
        label = out_dict[key]['labels'][0]
        if (label in top3_labels):
            correct_count+=1
            
    return out_dict,correct_count/test_length


hubertbase_12,top3_hubertbase_12_acc = get_top3_accuracy(hubertbase_12)
hubertbase_9,top3_hubertbase_9_acc = get_top3_accuracy(hubertbase_9)
hubertbase_6,top3_hubertbase_6_acc = get_top3_accuracy(hubertbase_6)




In [28]:
print(top3_hubertbase_12_acc,top3_hubertbase_9_acc,top3_hubertbase_6_acc)

0.995253164556962 0.9810126582278481 0.9984177215189873


In [38]:
lengths = []
for i in range(len(dataset['test']['id'])):
    length = len(dataset['test']['audio'][i]['array'])
    lengths.append(length/16000)
    
with open('test_lengths.txt', 'w') as f:
    for line in lengths:
        f.write(f"{line}\n")

def append_audio_lengths(out_dict):
    lengths = []
    with open('test_lengths.txt','r') as file:
        lengths = [float(line.rstrip()) for line in file]
    for length,key in zip(lengths,hubertbase_12.keys()):
        out_dict[key]['audio_length'] = length
        
    return out_dict

In [39]:
hubertbase_6 = append_audio_lengths(hubertbase_6)
hubertbase_9 = append_audio_lengths(hubertbase_9)
hubertbase_12 = append_audio_lengths(hubertbase_12)

In [40]:
def get_mistake_ids(out_dict):
    for key in out_dict.keys():
        if 

{'id10001-Y8hIVOBuels-00001': {'labels': array([0]),
  'probabilities': array([7.18385756e-01, 5.99877110e-08, 1.21949598e-04, 1.20490335e-04,
         2.30489230e-02, 1.83974116e-04, 9.69476806e-07, 2.79337685e-07,
         2.98068824e-03, 1.09419358e-04, 2.16083677e-06, 5.72366890e-10,
         1.02991708e-05, 6.05620046e-07, 6.41721129e-06, 1.46574409e-07,
         2.61478943e-06, 3.79092161e-06, 3.45806618e-06, 6.30345206e-04,
         2.33164651e-03, 6.66117387e-09, 3.63146925e-04, 8.42026143e-06,
         1.42402377e-08, 8.03792645e-04, 3.44196627e-03, 1.76949976e-07,
         4.67411405e-06, 1.91327090e-06, 7.65050772e-05, 1.01854133e-07,
         2.45616293e-06, 8.56446265e-08, 1.18324785e-02, 2.49904924e-07,
         2.92334936e-04, 4.23379801e-05, 2.88495286e-07, 1.32999812e-07,
         1.12255183e-08, 4.18247127e-05, 3.64129665e-08, 1.54038774e-07,
         2.11674395e-07, 7.00675050e-08, 8.24365072e-07, 4.53383414e-07,
         4.84905605e-09, 3.49652972e-05, 8.52798391e-0