### Author: Vignesh Srinivasa Naidu Prakash
### GTID: 903809799

### Dataset

In [1]:
from transformers import DataCollatorWithPadding,AutoModelForSequenceClassification, Trainer, TrainingArguments,AutoTokenizer,AutoModel,AutoConfig
from transformers import default_data_collator
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
import torch
import torch.nn as nn
from datasets import Dataset, Audio, Value, Features,load_dataset,ClassLabel
from transformers import Wav2Vec2Processor
from transformers.modeling_outputs import SequenceClassifierOutput
import numpy as np
from transformers import AdamW,get_scheduler
from datasets import load_metric
from tqdm.auto import tqdm
import os
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
checkpoint = "facebook/wav2vec2-base"
weighted_sum = True
# checkpoint = "facebook/hubert-base-ls960"
# checkpoint = "facebook/wav2vec2-large-lv60"

x = [str(i) for i in range(0,100,1)]
features = Features(
    {
        "id": Value("string"),
        "speaker_id": Value("string"), 
        'path': Value('string'),
        "audio": Audio(sampling_rate=16000),
        "label": ClassLabel(num_classes=100,names=x,names_file=None,id=None)
    }
)


dataset = load_dataset('csv', 
                       data_files={'train': '../data/identification/train_100.csv',
                                    'dev':'../data/identification/dev_100.csv', 
                                    'test': '../data/identification/test_100.csv'},
                       features=features)
dataset = dataset.map(remove_columns=(['path','speaker_id']),num_proc=24)
dataset = dataset.sort("label")
sampling_rate = dataset["train"].features["audio"].sampling_rate
if 'base' in checkpoint and not weighted_sum:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint,return_attention_mask=False)
else:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint,return_attention_mask=True)

Using custom data configuration default-32600726bfa6b6de
Found cached dataset csv (/storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 3/3 [00:00<00:00, 347.88it/s]
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-2c523a6aa4d437d1.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-752369dc0b7f5663.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-6649bf103c92c53a.arrow
Loading cached processed dataset at /s

In [3]:
dataset['dev']

Dataset({
    features: ['id', 'audio', 'label'],
    num_rows: 533
})

In [3]:
def featurize(batch):
#     audio_arrays = [batch['audio'][i]['array'] for i in range(len(batch))]
    audio_arrays = [batch['audio'][i]['array'] for i in range(len(batch['id']))]
    print(len(audio_arrays))
    inputs = feature_extractor(
        audio_arrays, 
        sampling_rate=16_000, 
        max_length=int(16_000 * 10),  # 10s
        truncation=True, 
        padding='max_length',
    )
    return inputs
dataset = dataset.map(featurize, remove_columns='audio',batched=True,num_proc=20)

if 'base' in checkpoint and not weighted_sum:
    dataset.set_format("torch",columns=["id","input_values", "label"])
else:
    dataset.set_format("torch",columns=["id","input_values", "attention_mask", "label"])
train_data_collator = default_data_collator
dev_data_collator = default_data_collator



train_dataloader = DataLoader(
    dataset["train"], shuffle=True, batch_size=32, collate_fn=train_data_collator
)
dev_dataloader = DataLoader(
    dataset["dev"], batch_size=32, collate_fn=dev_data_collator
)

Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-5055f45db9aefa06.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-a1ba8dec82a007b6.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-4fc3e9fb787afaa2.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/default-32600726bfa6b6de/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-f10b5955abda65be.arrow
Loading cached processed dataset at /storage/home/hcocice1/vkotra3/.cache/huggingface/datasets/csv/defau

### Model

In [4]:
class CustomBaseSID(nn.Module):
    def __init__(self,checkpoint,num_labels,inter_layer_num,attend):
        
        ### attend is a boolean
        super(CustomBaseSID, self).__init__()
#         self.hubert = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid")
        self.model =AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=attend,output_hidden_states=True))
        self.num_labels = num_labels
        self.attend_mask = attend
        ### New layers:
        self.linear1 = nn.Linear(768, 1024)
        self.linear2 = nn.Linear(1024, num_labels)
        
        ### Intermediate Layer Number
        self.layer_num = inter_layer_num

    def forward(self, input_values=None, attention_mask=None,labels=None):
        if(self.attend_mask):
            outputs = self.model(input_values=input_values, attention_mask=attention_mask)
        else:
            outputs = self.model(input_values=input_values, attention_mask=None)
        feature = outputs.hidden_states[self.layer_num-1]
        agg_vec_list = []
        for i in range(len(feature)):
            if(attention_mask==None):
                length = len(feature[i])
            
            else:
                if torch.nonzero(attention_mask[i] < 0, as_tuple=False).size(0) == 0:
                    length = len(feature[i])
                else:
                    length = torch.nonzero(attention_mask[i] < 0, as_tuple=False)[0] + 1
            agg_vec=torch.mean(feature[i][:length], dim=0)
            agg_vec_list.append(agg_vec)
        mean = torch.stack(agg_vec_list)
        # sequence_output has the following shape: (batch_size, sequence_length, 768)
        linear1_output = self.linear1(mean) ## extract the 1st token's embeddings
        logits = self.linear2(linear1_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)
    
    
class CustomBaseWeightedSumSID(nn.Module):
    def __init__(self,checkpoint,num_labels,inter_layer_num,attend,input_dim):
        
        ### attend is a boolean
        super(CustomBaseWeightedSumSID, self).__init__()
#         self.hubert = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid")
        self.model =AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))
        self.ds_factor = 320
        self.num_labels = num_labels
        self.attend_mask = attend
        ### New layers:
        self.linear1 = nn.Linear(input_dim, 1024)
        self.linear2 = nn.Linear(1024, num_labels)
        
        ### Intermediate Layer Number
        self.layer_num = inter_layer_num
        self.W = nn.Linear(input_dim, 1)

    def forward(self, input_values=None, attention_mask=None,labels=None):
        print(attention_mask.shape)
        outputs = self.model(input_values=input_values, attention_mask=attention_mask)
        batch_rep = outputs.hidden_states[self.layer_num-1]
        attention_mask = (1.0 - attention_mask) * -100000.0
        attention_mask
        print(batch_rep.shape)
        seq_len = batch_rep.shape[1]
        softmax = nn.functional.softmax
        att_logits = self.W(batch_rep).squeeze(-1)
        print(att_logits.shape)
        att_logits = att_logits
        att_w = softmax(att_logits, dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        linear1_output = self.linear1(utter_rep) ## extract the 1st token's embeddings
        logits = self.linear2(linear1_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)        
    
    

### Custom Trainer(Done!!!!)

In [18]:
def freeze_layers_transformer(model_ft,keywords,inter_layer,encoder_mode=True):
    ct = 0
    for child in model_ft.children():
        for name,param in child.named_parameters():
            if(encoder_mode):
                if 'encoder' in name and str(ct) in name:
                    ct += 1
                    if ct < inter_layer:
                        param.requires_grad = False
                    
            for word in keywords:         
                if word in name:
                    param.requires_grad = False
                            
    return model_ft

In [19]:
if 'base' in checkpoint and not weighted_sum:
    attend = False
else:
    attend = True
print(attend)
#     dataset = dataset.map(remove_columns=(['attention_mask']),num_proc=24)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inter_layer = 12
# model_ft = CustomHuBERTSID(checkpoint=checkpoint,num_labels=100,inter_layer_num=inter_layer,attend=attend).to(device)
model_ft = CustomBaseWeightedSumSID(checkpoint=checkpoint,num_labels=100,inter_layer_num=inter_layer,attend=attend,input_dim = 768).to(device)

keywords = ['spec_embed','feature_extractor','feature_projection']
model_ft = freeze_layers_transformer(model_ft,keywords,inter_layer,encoder_mode=True)

True


Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['quantizer.weight_proj.weight', 'quantizer.codevectors', 'project_q.weight', 'project_hid.bias', 'project_hid.weight', 'project_q.bias', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)
metric_name="accuracy"
metric = load_metric(metric_name)

progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(dev_dataloader)))
optimizer = AdamW(model_ft.parameters(), lr=1e-4)


lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)


DIRPATH = "/storage/home/hcocice1/vkotra3/6254_Project/code/w2v-base/"
PATH = DIRPATH+'attention/'
# os.mkdir(PATH)
dirpath=PATH+str(inter_layer)+'/'
# os.mkdir(dirpath)
metric_score = -10000000
for epoch in range(num_epochs):
    model_ft.train()
    for batch in train_dataloader:
#         print(batch)
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model_ft(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.update(1)

        
    model_ft.eval()
    for batch in dev_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model_ft(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
        progress_bar_eval.update(1)
    
    metric_score_epoch = metric.compute()['accuracy']
    print(metric_score_epoch)
    if (metric_score_epoch>metric_score):
        torch.save(model_ft.state_dict(), dirpath+str(epoch)+'.pt')


  0%|          | 5/1740 [00:26<2:34:25,  5.34s/it]


  0%|          | 0/85 [00:26<?, ?it/s][A[A


torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 1/1740 [00:01<51:10,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 2/1740 [00:03<50:57,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 3/1740 [00:05<50:53,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 4/1740 [00:07<50:50,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 5/1740 [00:08<50:48,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 6/1740 [00:10<50:50,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 7/1740 [00:12<50:49,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  0%|          | 8/1740 [00:14<50:47,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 9/1740 [00:15<50:49,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 10/1740 [00:17<50:47,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 11/1740 [00:19<50:48,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 12/1740 [00:21<50:48,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 13/1740 [00:22<50:46,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 14/1740 [00:24<50:44,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 15/1740 [00:26<50:40,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 16/1740 [00:28<50:40,  1.76s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 17/1740 [00:29<50:44,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 18/1740 [00:31<50:43,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 19/1740 [00:33<50:43,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 20/1740 [00:35<50:39,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|          | 21/1740 [00:37<50:36,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|▏         | 22/1740 [00:38<50:37,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])



  1%|▏         | 23/1740 [00:40<50:36,  1.77s/it][A

torch.Size([32, 160000])
torch.Size([32, 499, 768])
torch.Size([32, 499])


Exception ignored in: <function Metric.__del__ at 0x2aabf98a29d0>
Traceback (most recent call last):
  File "/storage/home/hcocice1/vkotra3/.local/lib/python3.9/site-packages/datasets/metric.py", line 642, in __del__
    def __del__(self):
KeyboardInterrupt: 

KeyboardInterrupt



### Inference

In [None]:
eval_dataloader = DataLoader(
    dataset["test"], batch_size=32, collate_fn=default_data_collator
)
inter_layer = 12
dirpath="/storage/home/hcocice1/vkotra3/6254_Project/code/hubert-base/"+str(inter_layer)+'/'
PATH = dirpath+str(4)+'.pt'
checkpoint = "facebook/wav2vec2-base"
# checkpoint = "facebook/"
if 'base' in checkpoint:
    attend = False
#     dataset = dataset.map(remove_columns=(['attention_mask']),num_proc=24)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
model_ft = CustomHuBERTSID(checkpoint=checkpoint,num_labels=100,inter_layer_num=inter_layer,attend=attend)
model_ft.load_state_dict(torch.load(PATH, map_location=device))

In [13]:
from datasets import load_metric
eval_dataloader = DataLoader(
    dataset["test"], batch_size=32, collate_fn=default_data_collator
)
inter_layer = 12
dirpath="/storage/home/hcocice1/vkotra3/6254_Project/code/hubert-base/"+str(inter_layer)+'/'
PATH = dirpath+str(4)+'.pt'
checkpoint = "facebook/wav2vec2-base"
# checkpoint = "facebook/"
if 'base' in checkpoint:
    attend = False
#     dataset = dataset.map(remove_columns=(['attention_mask']),num_proc=24)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
model_ft = CustomHuBERTSID(checkpoint=checkpoint,num_labels=100,inter_layer_num=inter_layer,attend=attend)
model_ft.load_state_dict(torch.load(PATH, map_location=device))


metric_name="accuracy"
metric = load_metric(metric_name)
model_ft.eval()

test_dataloader = DataLoader(
    dataset["test"], batch_size=32, collate_fn=default_data_collator
)

for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model_ft(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

{'accuracy': 0.745253164556962}