In [1]:
import jax.numpy as jnp
import jax
if jax.device_count()>1:
    dtype = jnp.bfloat16
    num_workers = 16
    try:
        import transformers
        speech_path = 'bengaliai-speech/train_mp3s/'
        data_path = 'bengaliai-speech/train.csv'
    except:
        !pip install librosa --quiet
        !pip install git+https://github.com/zhenlan0426/transformers --quiet
        !pip install evaluate --quiet
        !pip install sentencepiece --quiet
        !pip install jiwer --quiet
        speech_path = '/kaggle/input/bengaliai-speech/train_mp3s/'
        data_path = '/kaggle/input/bengaliai-speech/train.csv'
        
else:
    dtype = jnp.float16
    speech_path = 'data/train_mp3s/'
    data_path = 'data/train.csv'
    num_workers = 8

In [2]:
# from whisper_jax import FlaxWhisperForConditionalGeneration
from transformers import FlaxWhisperForConditionalGeneration
from functions import *
from functools import partial
import optax
import evaluate
from jax import random
from transformers import AutoTokenizer

In [3]:
# --- run configuration ---------------------------------------------------------
batch_size = 16           # training batch size; the test loader uses batch_size * 4
pad_to_multiple_of = 4    # forwarded to collate_fn for both loaders
max_length_gen = 48       # max_length passed to model.generate
epochs = 1
verbose = 625             # training steps between evaluation reports

# --- text / audio preprocessing -------------------------------------------------
# The BanglaT5 tokenizer is used in place of Whisper's own
# (previously: WhisperTokenizer.from_pretrained("openai/whisper-large-v2", language="bn", task="transcribe")).
tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglat5")
# Disable BOS on the tokenizer; the Whisper decoder prompt is configured on the model instead.
tokenizer.bos_token = None
tokenizer.bos_token_id = None
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")

# Transcript table; rows are split into train / test by the loader cell.
text = pd.read_csv(data_path)

In [4]:
# One collate function shared by both loaders. IsTrain=True is kept for the test split
# too, presumably so test batches also carry input_ids/attention_mask, which
# eval_one_step uses as targets — TODO confirm against collate_fn in `functions`.
collate = partial(
    collate_fn,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    pad_to_multiple_of=pad_to_multiple_of,
    IsTrain=True,
)
n_train_rows = 950000  # rows [0, n_train_rows) train; the remainder is held out

dataset = AudioDataset(text.iloc[:n_train_rows], speech_path)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate,
)

dataset = AudioDataset(text.iloc[n_train_rows:], speech_path)
test_loader = DataLoader(
    dataset,
    batch_size=batch_size * 4,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate,
)

In [5]:
# Pull a single training batch and convert it to JAX arrays (audio in the compute dtype).
raw_audio, raw_input_ids, raw_attention_mask = next(iter(train_loader))
audio = jnp.array(raw_audio, dtype=dtype)
input_ids = jnp.array(raw_input_ids)
attention_mask = jnp.array(raw_attention_mask)

In [6]:
# load the processor and model
# Load Whisper large-v2; with _do_init=False the parameters are returned separately
# from the module instead of being randomly initialised on it.
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",
    dtype=dtype,
    _do_init=False,
)

# Clear the config-level token controls; the decoder prompt is set on generation_config below.
for cfg_attr in ("forced_decoder_ids", "bos_token_id", "suppress_tokens", "decoder_start_token_id"):
    setattr(model.config, cfg_attr, None)

gen_cfg = model.generation_config
# '<|startoftranscript|><|bn|><|transcribe|><|notimestamps|>' as a multi-token decoder
# prompt (a list here presumably relies on the installed transformers fork).
gen_cfg.decoder_start_token_id = [50258, 50302, 50359, 50363]
# forced_decoder_ids is a list of [generation_index, token_id] pairs forced before
# sampling, e.g. [[0, 123]] forces token 123 at the first step. It is cleared together
# with the suppression lists and BOS so only the prompt above steers decoding.
for gen_attr in ("forced_decoder_ids", "suppress_tokens", "begin_suppress_tokens", "bos_token_id"):
    setattr(gen_cfg, gen_attr, None)

In [7]:
# ensure std of init is the same
std_ = params['model']['decoder']['embed_tokens']['embedding'].std().item()
# reset the embedding params
params['model']['decoder']['embed_tokens']['embedding'] = params['model']['decoder']['embed_tokens']['embedding'].at[:tokenizer.vocab_size]\
                                                                .set(random.normal(random.PRNGKey(7),(tokenizer.vocab_size,model.config.d_model)) * std_)
embedding = params['model']['decoder']['embed_tokens']['embedding']

In [8]:
# AdamW over the decoder embedding table only (pairs with train_one_step_embed).
embed_lr = 1e-3
opt = optax.adamw(learning_rate=embed_lr)
# For full fine-tuning with train_one_step, initialise from `params` instead of `embedding`.
opt_states = opt.init(embedding)

In [9]:
# Generation
# https://huggingface.co/transformers/v4.1.1/_modules/transformers/generation_logits_process.html
# https://huggingface.co/docs/transformers.js/api/utils/generation#module_utils/
# LogitsProcessor(input_ids: LongTensor of shape (batch_size, sequence_length), scores: FloatTensor of shape (batch_size, config.vocab_size)) → new scores: FloatTensor of shape (batch_size, config.vocab_size)

In [10]:
@jax.jit
def train_one_step_embed(embedding,params,audio,input_ids,attention_mask,opt_states):
    """One AdamW step that trains only the decoder token-embedding table.

    Args:
        embedding: decoder token-embedding matrix being optimised.
        params: full model parameter tree; its own embedding entry is replaced by
            `embedding` inside the loss (the caller's tree is not modified).
        audio: input audio features, already in the compute dtype.
        input_ids: decoder token ids; the first 4 positions are presumably the
            Whisper prompt tokens — TODO confirm in collate_fn.
        attention_mask: decoder mask aligned with `input_ids`; it zeroes out the loss
            at padded target positions.
        opt_states: optax state created by `opt.init(embedding)`.

    Returns:
        (new_embedding, new_opt_states, loss), where loss is the cross-entropy averaged
        over unmasked target tokens.
    """
    def loss_fn(embedding, params, audio, input_ids, attention_mask):
        # Swap the trainable embedding in by rebuilding the dicts along its path rather
        # than assigning into `params`, so no caller-visible pytree is mutated (also safe
        # when jit is disabled for debugging).
        decoder = params['model']['decoder']
        params = {
            **params,
            'model': {
                **params['model'],
                'decoder': {
                    **decoder,
                    'embed_tokens': {**decoder['embed_tokens'], 'embedding': embedding},
                },
            },
        }
        logits = model(audio, input_ids, decoder_attention_mask=attention_mask,
                       params=params, train=True).logits  # (B, L, vocab_size)
        # Logit position t predicts token t+1; scoring starts after the prompt
        # (logit index 3 -> target index 4). Upcast so log-softmax is computed in f32
        # even when the model runs in bf16/f16.
        ce = optax.softmax_cross_entropy_with_integer_labels(
            logits[:, 3:-1].astype(jnp.float32), input_ids[:, 4:])
        mask = attention_mask[:, 4:].astype(jnp.float32)
        # Average over real tokens only; a plain mean would also count padded positions
        # and make the loss depend on how much padding a batch has.
        return jnp.sum(ce * mask) / jnp.maximum(jnp.sum(mask), 1.0)

    loss, grads = jax.value_and_grad(loss_fn)(embedding, params, audio, input_ids, attention_mask)
    updates, opt_states = opt.update(grads, opt_states, params=embedding)
    embedding = optax.apply_updates(embedding, updates)
    return embedding, opt_states, loss

@jax.jit
def eval_one_step(params,audio,input_ids,attention_mask):
    """Teacher-forced cross-entropy on one batch (no parameter update).

    Args:
        params: full model parameter tree.
        audio: input audio features, already in the compute dtype.
        input_ids: decoder token ids; the first 4 positions are presumably the
            Whisper prompt tokens — TODO confirm in collate_fn.
        attention_mask: decoder mask aligned with `input_ids`; it zeroes out the loss
            at padded target positions.

    Returns:
        Scalar loss averaged over unmasked target tokens (same definition as the
        training steps).
    """
    logits = model(audio, input_ids, decoder_attention_mask=attention_mask,
                   params=params, train=False).logits  # (B, L, vocab_size)
    # Logit position t predicts token t+1; score from the first post-prompt token, in f32.
    ce = optax.softmax_cross_entropy_with_integer_labels(
        logits[:, 3:-1].astype(jnp.float32), input_ids[:, 4:])
    mask = attention_mask[:, 4:].astype(jnp.float32)
    # Token-level mean over real tokens; padded positions must not dilute the loss.
    return jnp.sum(ce * mask) / jnp.maximum(jnp.sum(mask), 1.0)

@jax.jit
def train_one_step(params,audio,input_ids,attention_mask,opt_states):
    """One AdamW step over the full parameter tree.

    NOTE(review): `opt_states` must be created with `opt.init(params)`. The notebook
    currently initialises it with `opt.init(embedding)`, which does not match this tree.

    Args:
        params: full model parameter tree (all leaves are trained).
        audio: input audio features, already in the compute dtype.
        input_ids: decoder token ids; the first 4 positions are presumably the
            Whisper prompt tokens — TODO confirm in collate_fn.
        attention_mask: decoder mask aligned with `input_ids`; it zeroes out the loss
            at padded target positions.
        opt_states: optax state matching the structure of `params`.

    Returns:
        (new_params, new_opt_states, loss), where loss is the cross-entropy averaged
        over unmasked target tokens.
    """
    def loss_fn(params, audio, input_ids, attention_mask):
        logits = model(audio, input_ids, decoder_attention_mask=attention_mask,
                       params=params, train=True).logits  # (B, L, vocab_size)
        # Logit position t predicts token t+1; score from the first post-prompt token, in f32.
        ce = optax.softmax_cross_entropy_with_integer_labels(
            logits[:, 3:-1].astype(jnp.float32), input_ids[:, 4:])
        mask = attention_mask[:, 4:].astype(jnp.float32)
        # Token-level mean over real tokens; padded positions must not dilute the loss.
        return jnp.sum(ce * mask) / jnp.maximum(jnp.sum(mask), 1.0)

    loss, grads = jax.value_and_grad(loss_fn)(params, audio, input_ids, attention_mask)
    updates, opt_states = opt.update(grads, opt_states, params=params)
    params = optax.apply_updates(params, updates)
    return params, opt_states, loss

# @jax.jit
# def generate(params,audio):
#     return model.generate(audio,params=params,max_length=max_length_gen)

metric = evaluate.load("wer")
def metric_one_step(params,audio,txt):
    """Greedy-decode `audio` and return the word error rate against `txt`.

    Args:
        params: model parameter tree passed to `model.generate`.
        audio: batch of input audio features.
        txt: reference transcripts, presumably one string per example — TODO confirm
            what the caller passes.

    Returns:
        WER as computed by the `evaluate` "wer" metric.
    """
    generated_ids = model.generate(audio, params=params, max_length=max_length_gen,
                                   num_beams=1, do_sample=False).sequences
    # The decoder keeps Whisper's output vocabulary (only the first tokenizer.vocab_size
    # embedding rows were re-initialised). Generated sequences can therefore contain ids
    # the BanglaT5 tokenizer presumably cannot map, e.g. the Whisper prompt ids
    # (50258, ...) or Whisper's pad id. Drop them explicitly: a sentencepiece-backed slow
    # tokenizer would raise on out-of-range ids, and a fast tokenizer would skip them anyway.
    n_known = len(tokenizer)
    sequences = [[tok for tok in seq if tok < n_known] for seq in generated_ids.tolist()]
    transcriptions = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return metric.compute(predictions=transcriptions, references=txt)

def batch_generate(loader):
    """Placeholder for batched generation over `loader`; not implemented yet.

    Currently a no-op that ignores `loader` and returns None.
    """
    # TODO: implement. An earlier draft post-processed decoded text with
    # transcriptions = [txt + "|" for txt in transcriptions]

In [None]:
def run_eval(params, loader):
    """Average `eval_one_step` loss over every batch of `loader`.

    Returns nan when the loader yields no batches (instead of failing on an
    undefined loop index).
    """
    total, n_batches = 0.0, 0
    for eval_audio, eval_ids, eval_mask in loader:
        eval_audio = jnp.array(eval_audio, dtype=dtype)
        total += eval_one_step(params, eval_audio, jnp.array(eval_ids), jnp.array(eval_mask)).item()
        n_batches += 1
    # Divide by the number of batches seen (the old `/= k` used the last index, one too few).
    return total / n_batches if n_batches else float('nan')


for epoch in range(epochs):
    train_loss, steps_since_report = 0.0, 0
    for j, (audio, input_ids, attention_mask) in enumerate(train_loader):
        audio, input_ids, attention_mask = jnp.array(audio, dtype=dtype), jnp.array(input_ids), jnp.array(attention_mask)
        embedding, opt_states, l = train_one_step_embed(embedding, params, audio, input_ids, attention_mask, opt_states)
        train_loss += l.item()
        steps_since_report += 1

        # Report every `verbose` steps (including step 0 as a baseline). Average the
        # training loss over the steps actually accumulated since the last report; a
        # fixed `/ verbose` is wrong at j == 0.
        if j % verbose == 0:
            # The trainable embedding lives outside `params`; sync it before evaluating.
            params['model']['decoder']['embed_tokens']['embedding'] = embedding
            eval_metric = run_eval(params, test_loader)
            # The eval number is a cross-entropy loss, not a WER.
            print(f"epoch:{epoch}, iterations:{j}, loss: {train_loss / steps_since_report:.3f}, eval loss: {eval_metric:.3f}")
            train_loss, steps_since_report = 0.0, 0

    # Leave `params` holding the final trained embedding, not the one from the last report.
    params['model']['decoder']['embed_tokens']['embedding'] = embedding