In [21]:
from transformers import DataCollatorWithPadding,AutoModelForSequenceClassification, Trainer, TrainingArguments,AutoTokenizer,AutoModel,AutoConfig
from transformers import default_data_collator
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
import torch
import torch.nn as nn
from datasets import Dataset, Audio, Value, Features,load_dataset,ClassLabel
from transformers import Wav2Vec2Processor
from transformers.modeling_outputs import SequenceClassifierOutput
import numpy as np
from transformers import AdamW,get_scheduler
from datasets import load_metric
from tqdm.auto import tqdm
import os
from torch.utils.data import DataLoader
from datasets import load_metric


def featurize(batch):
#     audio_arrays = [batch['audio'][i]['array'] for i in range(len(batch))]
    audio_arrays = [batch['audio'][i]['array'] for i in range(len(batch['id']))]
#     print(len(audio_arrays))
    inputs = feature_extractor(
        audio_arrays, 
        sampling_rate=16_000, 
        max_length=int(16_000 * 10),  # 10s
        truncation=True, 
        padding='max_length',
    )
    return inputs


checkpoint = "facebook/wav2vec2-base"
weighted_sum = False
# checkpoint = "facebook/hubert-base-ls960"
# checkpoint = "facebook/wav2vec2-large-lv60"

x = [str(i) for i in range(0,100,1)]
features = Features(
    {
        "id": Value("string"),
        "speaker_id": Value("string"), 
        'path': Value('string'),
        "audio": Audio(sampling_rate=16000),
        "label": ClassLabel(num_classes=100,names=x,names_file=None,id=None)
    }
)


dataset = load_dataset('csv', 
                       data_files={'test': 'test_100.csv'},
                       features=features)
dataset = dataset.map(remove_columns=(['path','speaker_id']),num_proc=24)
dataset = dataset.sort("label")
# sampling_rate = dataset.features["audio"].sampling_rate
if 'base' in checkpoint and not weighted_sum:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint,return_attention_mask=False)
else:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint,return_attention_mask=True)


dataset = dataset.map(featurize, remove_columns='audio',batched=True,num_proc=20,batch_size=1)

# dataset = dataset.rename_column('id','label_ids')

if 'base' in checkpoint and not weighted_sum:
    dataset.set_format("torch",columns=["id","input_values", "label"])
else:
    dataset.set_format("torch",columns=["id","input_values", "attention_mask", "label"])

Using custom data configuration default-a5a36436e0efa59f
Found cached dataset csv (/storage/home/hcocice1/vprakash40/.cache/huggingface/datasets/csv/default-a5a36436e0efa59f/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 346.84it/s]
Loading cached processed dataset at /storage/home/hcocice1/vprakash40/.cache/huggingface/datasets/csv/default-a5a36436e0efa59f/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-cf588e66c1064c8a.arrow
Loading cached processed dataset at /storage/home/hcocice1/vprakash40/.cache/huggingface/datasets/csv/default-a5a36436e0efa59f/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-ae320ba30df25ac7.arrow
Loading cached processed dataset at /storage/home/hcocice1/vprakash40/.cache/huggingface/datasets/csv/default-a5a36436e0efa59f/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-92e982191eb635ff.arrow
Loading cached processed d

### Model

In [22]:
class CustomBaseSID(nn.Module):
    def __init__(self,checkpoint,num_labels,inter_layer_num,attend):
        
        ### attend is a boolean
        super(CustomBaseSID, self).__init__()
#         self.hubert = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid")
        self.model =AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=attend,output_hidden_states=True))
        self.num_labels = num_labels
        self.attend_mask = attend
        ### New layers:
        self.linear1 = nn.Linear(768, 1024)
        self.linear2 = nn.Linear(1024, num_labels)
        
        ### Intermediate Layer Number
        self.layer_num = inter_layer_num

    def forward(self, input_values=None, attention_mask=None,labels=None):
        if(self.attend_mask):
            outputs = self.model(input_values=input_values, attention_mask=attention_mask)
        else:
            outputs = self.model(input_values=input_values, attention_mask=None)
        feature = outputs.hidden_states[self.layer_num-1]
        agg_vec_list = []
        for i in range(len(feature)):
            if(attention_mask==None):
                length = len(feature[i])
            
            else:
                if torch.nonzero(attention_mask[i] < 0, as_tuple=False).size(0) == 0:
                    length = len(feature[i])
                else:
                    length = torch.nonzero(attention_mask[i] < 0, as_tuple=False)[0] + 1
            agg_vec=torch.mean(feature[i][:length], dim=0)
            agg_vec_list.append(agg_vec)
        mean = torch.stack(agg_vec_list)
        # sequence_output has the following shape: (batch_size, sequence_length, 768)
        linear1_output = self.linear1(mean) ## extract the 1st token's embeddings
        logits = self.linear2(linear1_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)


In [23]:
def freeze_layers_transformer(model_ft,keywords,inter_layer,encoder_mode=True):
    ct = 0
    for child in model_ft.children():
        for name,param in child.named_parameters():
            if(encoder_mode):
                if 'encoder' in name and str(ct) in name:
                    ct += 1
                    if ct < inter_layer:
                        param.requires_grad = False
                    
            for word in keywords:         
                if word in name:
                    param.requires_grad = False
                            
    return model_ft

In [24]:
dataset

DatasetDict({
    test: Dataset({
        features: ['id', 'label', 'input_values'],
        num_rows: 632
    })
})

In [32]:
inter_layer = 12
dirpath="/storage/home/hcocice1/vprakash40/statmlfinal/wav2vec2-base/"+str(inter_layer)+'/'
PATH = dirpath+str(4)+'.pt'
checkpoint = "facebook/wav2vec2-base"

if 'base' in checkpoint:
    attend = False

device = "cpu"
model_ft = CustomBaseSID(checkpoint=checkpoint,num_labels=100,inter_layer_num=inter_layer,attend=attend)
#torch.save(model_ft.state_dict(), PATH)
model_ft.load_state_dict(torch.load(PATH, map_location=device))

metric_name="accuracy"
metric = load_metric(metric_name)
test_dataloader = DataLoader(dataset["test"], batch_size=1, collate_fn=default_data_collator)


Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['quantizer.codevectors', 'project_hid.weight', 'project_hid.bias', 'quantizer.weight_proj.bias', 'project_q.bias', 'quantizer.weight_proj.weight', 'project_q.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [33]:
# keys = list(dataset['test'].features.keys())
keys = list(dataset['test']['id'])
test_dict = dict.fromkeys(keys)

In [34]:
list(test_dict.keys())[0]

'id10001-Y8hIVOBuels-00001'

In [35]:
def evaluate(test_dataloader,model_ft):
    test_dict = dict.fromkeys(list(dataset['test']['id']))
    keys = list(test_dict.keys())
    
    for i,batch in enumerate(test_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model_ft(**batch)
        logits = outputs.logits
        sm = torch.nn.Softmax(dim=1)
        probabilities = sm(logits)
        test_dict[keys[i]] = dict.fromkeys(['labels','probabilities','predictions'])
        test_dict[keys[i]]['labels'] = batch['labels'].detach().numpy()
        test_dict[keys[i]]['probabilities'] = probabilities[0].detach().numpy()
        predictions = torch.argmax(logits, dim=-1)
        test_dict[keys[i]]['predictions'] = predictions[0].detach().numpy()
        metric.add_batch(predictions=predictions, references=batch["labels"])
    accuracy = metric.compute()
    return test_dict,accuracy

In [78]:
w2vbase_12,w2vbase_12_acuraccy = evaluate(test_dataloader,model_ft)

In [79]:
w2vbase_12

{'id10001-Y8hIVOBuels-00001': {'labels': array([0]),
  'probabilities': array([3.94151866e-01, 4.84396769e-05, 2.61282694e-04, 6.72647904e-04,
         1.38084879e-02, 3.24053661e-04, 7.41350741e-05, 1.01796277e-05,
         4.01328281e-02, 1.19098602e-03, 5.48486074e-04, 1.22963302e-05,
         2.25878000e-04, 1.65788388e-05, 6.58127130e-04, 2.21315986e-05,
         1.69755876e-05, 2.46690965e-04, 5.94912599e-05, 1.30180118e-03,
         7.17953430e-04, 2.49015284e-05, 2.79264757e-04, 1.18639204e-03,
         1.05882773e-05, 3.46929708e-04, 1.56726106e-03, 5.90060044e-05,
         4.04924183e-04, 1.03173617e-04, 5.08172147e-04, 3.59937367e-05,
         4.55926143e-04, 4.30797490e-06, 2.22485629e-03, 9.14659686e-05,
         5.86152542e-04, 1.89333834e-04, 2.28657736e-05, 9.02368265e-06,
         4.56147711e-04, 3.44344618e-04, 4.00233985e-05, 5.61897177e-05,
         3.82362923e-05, 2.51189595e-05, 5.35445033e-05, 2.91154196e-04,
         5.84922054e-06, 1.06375737e-04, 4.15303861e-0

In [80]:
w2vbase_12_acuraccy

{'accuracy': 0.9066455696202531}

In [99]:
for keys in w2vbase_12.keys():
    val=w2vbase_12[keys]
    top_3=sorted(zip(val['probabilities'],range(len(val['probabilities']))), reverse = True)[:3]
    print(top_3)

[(0.39415187, 0), (0.3919476, 89), (0.045318536, 65)]
[(0.5778379, 0), (0.17136477, 8), (0.09431505, 89)]
[(0.4999907, 0), (0.13617863, 76), (0.11626463, 4)]
[(0.7249681, 0), (0.08266102, 8), (0.040511295, 89)]
[(0.7790009, 0), (0.046819795, 8), (0.029893367, 89)]
[(0.70961976, 0), (0.05222873, 8), (0.040192958, 4)]
[(0.82332355, 0), (0.033286825, 65), (0.025629457, 81)]
[(0.79613066, 0), (0.03129144, 8), (0.022422522, 4)]
[(0.815222, 0), (0.0345636, 89), (0.023978433, 8)]
[(0.77257663, 0), (0.072416246, 81), (0.03961929, 89)]
[(0.9770155, 1), (0.0049822247, 2), (0.0033435107, 16)]
[(0.9782904, 1), (0.0038638574, 2), (0.0037046752, 16)]
[(0.97875696, 1), (0.003588938, 16), (0.0035391343, 2)]
[(0.9768908, 1), (0.005765291, 2), (0.0033948992, 16)]
[(0.6252126, 1), (0.33847442, 2), (0.009967675, 16)]
[(0.9786678, 1), (0.0038346346, 2), (0.003816273, 16)]
[(0.97822034, 1), (0.0036112233, 2), (0.0032002514, 16)]
[(0.978228, 1), (0.003692672, 2), (0.002905657, 16)]
[(0.9848305, 2), (0.002242

In [83]:
w2vbase_12

{'id10001-Y8hIVOBuels-00001': {'labels': array([0]),
  'probabilities': array([3.94151866e-01, 4.84396769e-05, 2.61282694e-04, 6.72647904e-04,
         1.38084879e-02, 3.24053661e-04, 7.41350741e-05, 1.01796277e-05,
         4.01328281e-02, 1.19098602e-03, 5.48486074e-04, 1.22963302e-05,
         2.25878000e-04, 1.65788388e-05, 6.58127130e-04, 2.21315986e-05,
         1.69755876e-05, 2.46690965e-04, 5.94912599e-05, 1.30180118e-03,
         7.17953430e-04, 2.49015284e-05, 2.79264757e-04, 1.18639204e-03,
         1.05882773e-05, 3.46929708e-04, 1.56726106e-03, 5.90060044e-05,
         4.04924183e-04, 1.03173617e-04, 5.08172147e-04, 3.59937367e-05,
         4.55926143e-04, 4.30797490e-06, 2.22485629e-03, 9.14659686e-05,
         5.86152542e-04, 1.89333834e-04, 2.28657736e-05, 9.02368265e-06,
         4.56147711e-04, 3.44344618e-04, 4.00233985e-05, 5.61897177e-05,
         3.82362923e-05, 2.51189595e-05, 5.35445033e-05, 2.91154196e-04,
         5.84922054e-06, 1.06375737e-04, 4.15303861e-0