## EPAlign Prompt-Audio Fine-tuning

### Config

In [58]:
import os
import clip
import numpy as np
import librosa
from torch.utils.data import Dataset
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, BatchSampler
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)
from transformers import Wav2Vec2Processor
import logging
from tqdm import tqdm

# DATASET is the name of the dataset the model is trained on, e.g. ESD or MELD.
DATASET = "ESD"  # MELD

# BATCH_SIZE must not exceed the number of emotion categories (5 for ESD):
# BalancedBatchSampler draws one sample from each of BATCH_SIZE distinct emotions.
BATCH_SIZE = 5
EPOCH = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

# PROJECT_PATH is two directories above the current working directory
# (presumably the notebook lives two levels below the project root — confirm if moved).
# os.path.dirname is used instead of re-joining the split cwd onto "/", which
# produced a drive-relative path such as "C:x" on Windows.
PROJECT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
# PROCESSED_WAV2VEC2_PATH is the directory of the Wav2Vec2Processor.
PROCESSED_WAV2VEC2_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base/wav2vec2"
# PRETRAIN_WAV2VEC2_PATH is the pretrained wav2vec2 model directory, e.g. EPAlign/ckpt/base/wav2vec2.
PRETRAIN_WAV2VEC2_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base/wav2vec2"
# ESD_FILELIST_PATH is the directory containing the train/val/test filelists for DATASET.
ESD_FILELIST_PATH = f"{PROJECT_PATH}/EMITTS/filelist/{DATASET}"
# PRETRAIN_CLIP_MODEL is the pretrained CLIP model name passed to clip.load, e.g. ViT-B/32.
PRETRAIN_CLIP_MODEL = "ViT-B/32"
# PRETRAIN_CLIP_MODEL_PATH is the download/cache directory for the CLIP weights, e.g. EPAlign/ckpt/base.
PRETRAIN_CLIP_MODEL_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/base"
# LOG_PATH is the directory for training logs, e.g. EPAlign/log.
LOG_PATH = f"{PROJECT_PATH}/EPAlign/log"
# CKPT_PATH is the directory where checkpoints are saved, e.g. EPAlign/ckpt/ESD.
CKPT_PATH = f"{PROJECT_PATH}/EPAlign/ckpt/{DATASET}"

### Prompt-Audio Model (consisting of a language model and an acoustic model)

In [None]:
class CLAP(Wav2Vec2PreTrainedModel):
    """Contrastive audio/prompt model: wav2vec2 audio encoder + CLIP text encoder.

    Each utterance is encoded by wav2vec2, mean-pooled over time and linearly
    projected (``proj``) into the CLIP text-embedding space. Prompts are
    encoded by CLIP's text encoder. ``forward`` returns CLIP-style
    temperature-scaled cosine-similarity logits.

    Args:
        config: wav2vec2 model config (supplied by ``from_pretrained``).
        prompt_pretrain_model: CLIP model name for ``clip.load``, e.g. ``"ViT-B/32"``.
        prompt_pretrain_model_path: ``download_root`` directory for ``clip.load``.
    """

    # Output width of ``proj``. It must equal the width of CLIP's
    # ``encode_text`` output (512 for ViT-B/32).
    PROMPT_EMBED_DIM = 512

    def __init__(self, config, prompt_pretrain_model, prompt_pretrain_model_path):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.init_weights()
        # Pooled wav2vec2 states are config.hidden_size wide. The previous
        # hard-coded 1024 only matched "large" wav2vec2 checkpoints.
        width = config.hidden_size
        scale = width ** -0.5
        self.proj = nn.Parameter(scale * torch.randn(width, self.PROMPT_EMBED_DIM))
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.prompt_model, self.prompt_processor = clip.load(
            prompt_pretrain_model, jit=False, download_root=prompt_pretrain_model_path
        )
        self.prompt_model.to(device)

    def forward(self, wavs, prompts):
        """Return ``(logits_per_audio, logits_per_text)`` for paired wavs and prompts.

        Args:
            wavs: sequence of input tensors, each passed to wav2vec2 separately.
                Assumes each is shaped (1, T), as produced by ESDDataset
                (TODO confirm for other callers).
            prompts: list of text prompts, tokenised with ``clip.tokenize``.

        Returns:
            logits_per_audio: ``(len(wavs), len(prompts))`` scaled cosine similarities.
            logits_per_text: the transpose of ``logits_per_audio``.
        """
        # Utterances differ in length, so encode them one by one, mean-pool the
        # last hidden state over time (dim 1), then concatenate along the batch.
        pooled = [self.wav2vec2(wav)[0].mean(dim=1) for wav in wavs]
        audio_features = torch.cat(pooled, dim=0) @ self.proj

        prompt_tokens = clip.tokenize(prompts).to(device)
        prompt_features = self.prompt_model.encode_text(prompt_tokens)

        # Cast before normalising. clip.load yields an fp16 text model on CUDA,
        # and computing the norm in fp32 avoids fp16 precision loss.
        audio_features = audio_features.float()
        prompt_features = prompt_features.float()
        audio_features = audio_features / audio_features.norm(dim=1, keepdim=True)
        prompt_features = prompt_features / prompt_features.norm(dim=1, keepdim=True)

        # Cosine similarity scaled by the learned temperature, as logits.
        logit_scale = self.logit_scale.exp().float()
        logits_per_audio = logit_scale * audio_features @ prompt_features.t()
        logits_per_text = logits_per_audio.t()
        return logits_per_audio, logits_per_text

# Build the audio/prompt model. wav2vec2 weights are loaded from
# PRETRAIN_WAV2VEC2_PATH; the CLIP text model is loaded inside CLAP.__init__.
model = CLAP.from_pretrained(
    PRETRAIN_WAV2VEC2_PATH,
    prompt_pretrain_model=PRETRAIN_CLIP_MODEL,
    prompt_pretrain_model_path=PRETRAIN_CLIP_MODEL_PATH,
).to(device)

### Load the Wav2Vec2 Processor

In [60]:
# Processor matching the wav2vec2 checkpoint; ESDDataset uses it (as
# `preprocess`) to turn raw 16 kHz waveforms into model input_values.
processor = Wav2Vec2Processor.from_pretrained(
    PROCESSED_WAV2VEC2_PATH,
)

### Define Dataset

In [61]:
class ESDDataset(Dataset):
    """Dataset of ``(audio, prompt, label)`` triples read from a filelist.

    Each non-blank line of ``datalist`` is split on ``|``.
    - Field 0 is the wav path.
    - Field 3 is a path whose file name, up to the first ``.``, is the emotion
      tag (e.g. ``.../angry.<ext>`` -> ``"angry"``).

    Judging by the filelist names (``..._audio_sid_text_efeature_...``), the
    fields are presumably audio|speaker id|text|emotion feature — TODO confirm.

    Args:
        datalist: path to the ``|``-separated, UTF-8 filelist.
        preprocess: optional processor (e.g. ``Wav2Vec2Processor``). It is
            called as ``preprocess(wav, sampling_rate=16000)`` and must return
            a mapping with ``"input_values"``. When ``None``, the raw waveform
            is used as the audio input.
    """

    # Every wav is resampled to this rate on load and passed to the processor.
    SAMPLE_RATE = 16000

    def __init__(self,
                 datalist="path/to/datalist",
                 preprocess=None):
        self.datalist = datalist
        self.preprocess = preprocess
        self.data = self.load_data()
        # Emotion tag -> integer class label.
        self.text2label = {
            "angry": 1,
            "happy": 2,
            "neutral": 3,
            "sad": 4,
            "surprise": 5,
        }

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _emotion_tag(row):
        """Return the emotion tag: the file name of field 3 of *row*, up to its first ``.``."""
        return row[3].split("/")[-1].split(".")[0]

    def __getitem__(self, idx):
        """Return ``(audio, prompt, label)`` for item *idx*.

        ``audio`` is a float32 tensor of shape ``(1, T)`` on ``device``.
        ``prompt`` is a sentence naming the emotion. ``label`` is its integer
        class from ``text2label``.
        """
        row = self.data[idx]
        wav, _ = librosa.load(row[0], sr=self.SAMPLE_RATE)
        if self.preprocess is not None:
            audio = self.preprocess(wav, sampling_rate=self.SAMPLE_RATE)["input_values"][0]
        else:
            # Previously `audio` was left unbound on this path (UnboundLocalError).
            audio = wav
        audio = torch.from_numpy(audio.reshape(1, -1)).to(device).float()

        emotiontag = self._emotion_tag(row)
        prompt = f"A person speaking with a feeling of {emotiontag}"
        label = self.text2label[emotiontag]

        return audio, prompt, label

    def load_data(self):
        """Read the filelist; return one list of ``|``-separated fields per non-blank line."""
        with open(self.datalist, encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline-only line), which would
            # otherwise become [''] and crash on field indexing.
            return [line.strip().split("|") for line in f if line.strip()]
    
if DATASET == "ESD":
    train_dataset = ESDDataset(datalist=f'{ESD_FILELIST_PATH}/esd_en_audio_sid_text_efeature_train_filelist.txt', preprocess=processor)
    val_dataset = ESDDataset(datalist=f'{ESD_FILELIST_PATH}/esd_en_audio_sid_text_efeature_val_filelist.txt', preprocess=processor)
    test_dataset = ESDDataset(datalist=f'{ESD_FILELIST_PATH}/esd_en_audio_sid_text_efeature_test_filelist.txt', preprocess=processor)

    # Sanity-check the expected split sizes of the English ESD filelists; a
    # mismatch usually means a stale or truncated filelist. An explicit raise
    # is used rather than `assert`, which is stripped under `python -O`.
    for _split, _ds, _expected in (("train", train_dataset, 14_000),
                                   ("val", val_dataset, 1_750),
                                   ("test", test_dataset, 1_750)):
        if len(_ds) != _expected:
            raise ValueError(f"ESD {_split} split has {len(_ds)} items, expected {_expected}")
else:
    # Only ESD has a loader here; fail fast instead of a NameError on train_dataset.
    raise NotImplementedError(f"no dataset loader defined for DATASET={DATASET!r}")

### Define Batch Sampler (ensures no class appears twice in a batch)

In [63]:
class BalancedBatchSampler(BatchSampler):
    """Yield index batches with ``n_samples`` items from each of ``n_classes`` distinct classes.

    Adapted from an MNIST-style balanced sampler. For every batch:
    - ``n_classes`` labels are drawn without replacement.
    - The next ``n_samples`` unused indices of each drawn class are taken.

    This gives batches of ``n_classes * n_samples`` indices. A class's index
    list is reshuffled, and its cursor reset, once fewer than ``n_samples``
    unused indices remain.

    Args:
        labels: per-item class labels. Accepts a 1-D CPU tensor, NumPy array
            or sequence.
        n_classes: distinct classes per batch. Must be at least 1 and at most
            the number of distinct labels.
        n_samples: indices taken from each drawn class per batch (at least 1).

    Raises:
        ValueError: if ``n_classes`` or ``n_samples`` is out of range.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        label_array = np.asarray(labels)
        self.labels_set = list(set(label_array))
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        if not 1 <= n_classes <= len(self.labels_set):
            raise ValueError(
                f"n_classes must be in [1, {len(self.labels_set)}] (distinct labels), got {n_classes}"
            )
        self.label_to_indices = {label: np.where(label_array == label)[0]
                                 for label in self.labels_set}
        for indices in self.label_to_indices.values():
            np.random.shuffle(indices)
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(label_array)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # `<=` rather than `<`, so the number of yielded batches equals
        # __len__ (n_dataset // batch_size); `<` dropped the last full batch.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        return self.n_dataset // self.batch_size

def collate_fn(batch):
    """Split ``(audio, prompt, label)`` samples into three parallel collections.

    Audio tensors are returned as a plain list instead of being stacked.
    Utterances presumably differ in length, and CLAP.forward encodes them one
    at a time anyway. Labels are gathered into a tensor moved to ``device``.
    """
    audios, prompts, labels = [], [], []
    for sample in batch:
        audios.append(sample[0])
        prompts.append(sample[1])
        labels.append(sample[2])
    return audios, prompts, torch.tensor(labels).to(device)

def _labels_without_audio(dataset):
    """Return a tensor with the integer label of every item, without loading audio.

    Iterating the dataset (``item[2] for item in dataset``) decodes every wav
    with librosa and runs the processor just to read a label. For
    ESDDataset-like datasets, the label comes from the file name of field 3 of
    each filelist row, so it is read directly. This mirrors the tag parsing in
    ESDDataset.__getitem__. Other datasets fall back to full iteration.
    """
    if hasattr(dataset, "data") and hasattr(dataset, "text2label"):
        return torch.tensor([
            dataset.text2label[row[3].split("/")[-1].split(".")[0]]
            for row in dataset.data
        ])
    return torch.tensor([item[2] for item in dataset])


# One sample from each of BATCH_SIZE distinct emotions per batch.
train_labels = _labels_without_audio(train_dataset)
train_sampler = BalancedBatchSampler(train_labels, BATCH_SIZE, 1)
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler, collate_fn=collate_fn)

test_labels = _labels_without_audio(test_dataset)
test_sampler = BalancedBatchSampler(test_labels, BATCH_SIZE, 1)
test_dataloader = DataLoader(test_dataset, batch_sampler=test_sampler, collate_fn=collate_fn)

### Train Config

In [64]:
# https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    """Cast every parameter of *model*, and its gradient if present, to float32 in place.

    This is the workaround from openai/CLIP#57 for fine-tuning the fp16 CLIP
    weights: cast to fp32 before the optimizer step. Parameters that have not
    received a gradient (``p.grad is None``, e.g. frozen or unused ones) are
    still cast. Previously they raised ``AttributeError`` on ``p.grad.data``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

# Symmetric contrastive objective: one cross-entropy over the audio->prompt
# logits and one over the prompt->audio logits (averaged in the training loop).
loss_audio, loss_txt = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

# Only the audio projection and the learnable temperature are handed to the
# optimizer, so the wav2vec2 and CLIP weights are not updated by it.
# Full-model variants tried earlier:
#   optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
#   optim.Adam(model.parameters(), lr=1e-5)
optimizer = optim.Adam(params=[model.proj, model.logit_scale], lr=1e-5)

# Cosine decay over the whole run. T_max counts optimizer steps because
# scheduler.step() is called once per training batch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=len(train_dataloader) * EPOCH,
)


### Train Log

In [None]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

log = logging.getLogger('')
# basicConfig is a no-op when the root logger already has handlers (common in
# notebooks). Set the level explicitly so INFO records reach the file handler.
log.setLevel(logging.INFO)

# FileHandler does not create missing directories.
os.makedirs(LOG_PATH, exist_ok=True)
_log_file = os.path.abspath(f"{LOG_PATH}/log_prompt_audio_{DATASET}.txt")

# Re-running this cell would otherwise attach another handler for the same
# file and duplicate every log line, so reuse an existing one if present.
file_handler = next(
    (h for h in log.handlers
     if isinstance(h, logging.FileHandler) and h.baseFilename == _log_file),
    None,
)
if file_handler is None:
    file_handler = logging.FileHandler(_log_file, encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    log.addHandler(file_handler)

log.info('finetune start...')

### Train

In [None]:
def _contrastive_loss(logits_per_audio, logits_per_text):
    """Return the symmetric (audio->text, text->audio) cross-entropy for one batch.

    Row ``i`` of each logit matrix is paired with column ``i``, so the target
    class for row ``i`` is ``i``. The target length comes from the logits
    rather than BATCH_SIZE, so any batch size works.
    """
    ground_truth = torch.arange(logits_per_audio.shape[0], device=logits_per_audio.device)
    return (loss_audio(logits_per_audio, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2


best_te_loss = float("inf")
best_ep = -1
# torch.save does not create missing directories.
os.makedirs(CKPT_PATH, exist_ok=True)

for epoch in range(EPOCH):
    print(f"running epoch {epoch}, best test loss {best_te_loss} after epoch {best_ep}")

    # ---- train ----
    model.train()
    tr_loss, step = 0.0, 0
    pbar = tqdm(train_dataloader, leave=False)
    for audios, prompts, _ in pbar:
        step += 1
        optimizer.zero_grad()
        logits_per_audio, logits_per_text = model(audios, prompts)
        total_loss = _contrastive_loss(logits_per_audio, logits_per_text)
        total_loss.backward()
        tr_loss += total_loss.item()
        optimizer.step()
        scheduler.step()
        pbar.set_description(f"train batchCE: {total_loss.item()}", refresh=True)
    tr_loss = tr_loss / step if step else float("nan")

    # ---- evaluate ----
    model.eval()
    te_loss, step = 0.0, 0
    with torch.no_grad():
        test_pbar = tqdm(test_dataloader, leave=False)
        for audios, texts, _ in test_pbar:
            step += 1
            logits_per_audio, logits_per_text = model(audios, texts)
            total_loss = _contrastive_loss(logits_per_audio, logits_per_text)
            te_loss += total_loss.item()
            test_pbar.set_description(f"test batchCE: {total_loss.item()}", refresh=True)
    # An empty test loader yields inf, so it can never be recorded as the best epoch.
    te_loss = te_loss / step if step else float("inf")

    # Keep only the checkpoint with the lowest test loss so far.
    if te_loss < best_te_loss:
        best_te_loss = te_loss
        best_ep = epoch
        torch.save(model.state_dict(), f"{CKPT_PATH}/best_model_proj_logit.pt")
        torch.save(model.prompt_model.state_dict(), f"{CKPT_PATH}/best_model.pt")
    print(f"epoch {epoch}, tr_loss {tr_loss}, te_loss {te_loss}")
    # Also record epoch results in the log file (root logger).
    logging.info("epoch %d, tr_loss %s, te_loss %s", epoch, tr_loss, te_loss)