In [1]:
import torch
import math
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device=torch.device('cuda')
else:
    device=torch.device('cpu')

模型

In [2]:
# Model hyper-parameters shared by the encoder and decoder stacks.
AUDIO_FEATURE_SIZE = 80    # per-frame audio feature dimension (input width of the audio projection)
D_MODEL = 256              # Transformer hidden size
NHEAD = 4                  # attention heads per layer
DIM_FEED_FORWARD = 2048    # inner width of each feed-forward sublayer
ENCODER_NUM_LAYERS = 6
DECODER_NUM_LAYERS = 6
MAX_SEQ_LEN = 2000         # rows in the positional-encoding table (longest encodable sequence)

In [3]:
class PositionalEmbedding(torch.nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Even feature columns hold sin(pos / 10000^(2i/dim)) and odd columns hold the matching cos.

    Args:
        dim: feature size of the inputs (odd sizes are supported).
        seq_max_len: number of precomputed positions; longer inputs raise ValueError.
    """
    def __init__(self,dim,seq_max_len):
        super().__init__()
        position_idx=torch.arange(0,seq_max_len,dtype=torch.float).unsqueeze(-1)  # (seq_max_len, 1)
        # 1 / 10000^(2i/dim) for every even feature index 2i.
        div_term=torch.exp(-torch.arange(0,dim,2,dtype=torch.float)*math.log(10000.0)/dim)
        position_emb_fill=position_idx*div_term  # (seq_max_len, ceil(dim/2))
        pos_encoding=torch.zeros(seq_max_len,dim)
        pos_encoding[:,0::2]=torch.sin(position_emb_fill)
        # With an odd dim there is one fewer odd column than even column; drop the last frequency.
        pos_encoding[:,1::2]=torch.cos(position_emb_fill[:,:dim//2])
        # Registered as a (persistent) buffer so it follows .to(device) and stays in state_dict.
        self.register_buffer('pos_encoding',pos_encoding)

    def forward(self,x):    # x: (batch_size,seq_len,dim)
        """Return x plus the encodings for its first x.size(1) positions (same shape as x)."""
        seq_len=x.size(1)
        max_len=self.pos_encoding.size(0)
        if seq_len>max_len:
            raise ValueError(f'sequence length {seq_len} exceeds positional table size {max_len}')
        return x+self.pos_encoding[:seq_len].unsqueeze(0)

In [4]:
class TransformerASR(torch.nn.Module):
    """Encoder-decoder Transformer mapping audio feature frames to per-token logits.

    Hyper-parameters come from the module-level constants (AUDIO_FEATURE_SIZE, D_MODEL, NHEAD,
    DIM_FEED_FORWARD, ENCODER_NUM_LAYERS, DECODER_NUM_LAYERS, MAX_SEQ_LEN). All tensors are
    batch-first. Padding masks are bool with True marking padded positions, the convention of the
    ``*_key_padding_mask`` arguments of torch.nn.TransformerEncoder/TransformerDecoder.

    Args:
        vocab_size: number of tokens (embedding table size and logits width).
    """
    def __init__(self,vocab_size):
        super().__init__()
        # Two-layer MLP projecting each audio frame to the model width.
        self.audio_fc=torch.nn.Sequential(
            torch.nn.Linear(in_features=AUDIO_FEATURE_SIZE,out_features=D_MODEL),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=D_MODEL,out_features=D_MODEL)
        )
        # A single sinusoidal table shared by audio frames and decoder tokens.
        self.pos_emb=PositionalEmbedding(dim=D_MODEL,seq_max_len=MAX_SEQ_LEN)
        self.encoder=torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(d_model=D_MODEL,nhead=NHEAD,dim_feedforward=DIM_FEED_FORWARD,batch_first=True,norm_first=True),
            num_layers=ENCODER_NUM_LAYERS
        )
        self.token_emb=torch.nn.Embedding(num_embeddings=vocab_size,embedding_dim=D_MODEL)
        self.decoder=torch.nn.TransformerDecoder(
            decoder_layer=torch.nn.TransformerDecoderLayer(d_model=D_MODEL,nhead=NHEAD,dim_feedforward=DIM_FEED_FORWARD,batch_first=True,norm_first=True),
            num_layers=DECODER_NUM_LAYERS
        )
        # Projects decoder states to unnormalised logits over the vocabulary.
        self.prob_fc=torch.nn.Linear(D_MODEL,vocab_size)

    @staticmethod
    def _causal_mask(size,device):
        """Return a (size, size) bool mask that is True strictly above the diagonal.

        True entries are blocked, so position i attends only to positions <= i. The mask is
        allocated directly on `device`, avoiding a host allocation and copy on every decode call.
        """
        return torch.triu(torch.ones(size,size,dtype=torch.bool,device=device),diagonal=1)

    def encode(self,audio_features,audio_pad_mask):
        """Encode audio frames.

        Args:
            audio_features: (batch, frames, AUDIO_FEATURE_SIZE) float tensor.
            audio_pad_mask: (batch, frames) bool tensor, True at padded frames.
        Returns:
            (batch, frames, D_MODEL) encoder output.
        """
        audio_features=self.audio_fc(audio_features)
        audio_features=self.pos_emb(audio_features)
        enc_out=self.encoder(audio_features,src_key_padding_mask=audio_pad_mask)
        return enc_out

    def decode(self,enc_out,audio_pad_mask,token_ids,token_pad_mask):
        """Run the causal decoder over `token_ids`, attending to `enc_out`.

        Args:
            enc_out: (batch, frames, D_MODEL) output of `encode`.
            audio_pad_mask: (batch, frames) bool mask used for `enc_out`.
            token_ids: (batch, tgt_len) int64 decoder input ids.
            token_pad_mask: (batch, tgt_len) bool tensor, True at padded tokens.
        Returns:
            (batch, tgt_len, vocab_size) logits for the next token at each position.
        """
        token_embs=self.token_emb(token_ids)
        token_embs=self.pos_emb(token_embs)
        tokens_causal_mask=self._causal_mask(token_ids.size(1),token_ids.device)
        dec_out=self.decoder(tgt=token_embs,memory=enc_out,tgt_mask=tokens_causal_mask,tgt_key_padding_mask=token_pad_mask,memory_key_padding_mask=audio_pad_mask)
        final_out=self.prob_fc(dec_out)
        return final_out

    def forward(self,audio_features,audio_pad_mask,token_ids,token_pad_mask):
        """Teacher-forced pass: encode the audio, then decode `token_ids`; returns logits."""
        enc_out=self.encode(audio_features,audio_pad_mask)
        final_out=self.decode(enc_out,audio_pad_mask,token_ids,token_pad_mask)
        return final_out

In [5]:
from process_data import load_metadata,load_sample,load_tokenizer

# Smoke test: push one training utterance through an untrained model on the CPU.
tokenizer=load_tokenizer()
train_metas=load_metadata('data/train.txt')
sample=load_sample(train_metas[0])

audio_features=sample['audio_features'][None]                          # add a batch dimension
audio_pad_mask=torch.zeros(audio_features.shape[:2],dtype=torch.bool)  # single sample: nothing padded
token_ids=torch.tensor([sample['tokens'].ids[:-1]],dtype=torch.long)   # decoder input drops the final id
token_pad_mask=torch.zeros_like(token_ids,dtype=torch.bool)

model=TransformerASR(vocab_size=tokenizer.get_vocab_size())
dec_out=model(audio_features,audio_pad_mask,token_ids,token_pad_mask)
dec_out.shape



torch.Size([1, 13, 500])

数据集

In [6]:
class LRS2Dataset(torch.utils.data.Dataset):
    """Map-style dataset over one LRS2 split listed in ``data/<split>.txt``.

    Subclassing torch.utils.data.Dataset makes it composable with utilities such as
    ConcatDataset/Subset, in addition to working with DataLoader.

    Args:
        split: split name; metadata is read from ``data/{split}.txt``.
    """
    def __init__(self,split='train'):
        super().__init__()
        # One metadata entry per utterance, as returned by process_data.load_metadata.
        self.metas=load_metadata(f'data/{split}.txt')

    def __len__(self):
        """Number of utterances in the split."""
        return len(self.metas)

    def __getitem__(self,idx):
        """Load utterance `idx` and return ``(audio_features, token_ids)``.

        `audio_features` is the tensor load_sample stores under 'audio_features' (presumably
        (num_frames, AUDIO_FEATURE_SIZE) — verify against process_data). `token_ids` is a 1-D
        int64 tensor of the tokenizer ids of the transcript.
        """
        metaname=self.metas[idx]
        sample=load_sample(metaname)

        audio_features=sample['audio_features']
        token_ids=torch.tensor(sample['tokens'].ids,dtype=torch.long)
        return audio_features,token_ids

In [7]:
# Look at one raw (unpadded) training example.
train_ds = LRS2Dataset(split='train')
(audio_features, token_ids) = train_ds[0]
print(f'audio_features:{audio_features.shape},token_ids:{token_ids.shape}')

audio_features:torch.Size([139, 80]),token_ids:torch.Size([14])


训练

In [8]:
# Training configuration.
LR = 1e-4                     # initial Adam learning rate
EPOCHS = 50
BATCH_SIZE = 48
CHECKPOINT = 'checkpoint.pt'  # file read on resume and written after each epoch

In [9]:
import torch.utils

def collate_fn(batch,pad_id=None):
    """Pad a list of ``(audio_features, token_ids)`` samples into teacher-forcing tensors.

    Args:
        batch: list of (audio_features (frames_i, feat), token_ids (len_i,) int64) pairs.
        pad_id: id used to pad token sequences; defaults to the global tokenizer's '[PAD]' id.
    Returns:
        Tuple of
          audio_features (B, max_frames, feat), zero-padded;
          audio_pad_mask (B, max_frames) bool, True at padded frames;
          token_ids (B, max_len-1): decoder inputs (all but the last position);
          token_pad_mask (B, max_len-1) bool, True at padded decoder inputs;
          next_token_ids (B, max_len-1): targets (all but the first position), pad_id-padded.
    Raises:
        ValueError: if no pad id is given and the tokenizer has no '[PAD]' token.
    """
    if pad_id is None:
        pad_id=tokenizer.token_to_id('[PAD]')
        if pad_id is None:
            raise ValueError("tokenizer has no '[PAD]' token")
    audio_list=[sample[0] for sample in batch]
    token_list=[sample[1] for sample in batch]
    audio_lens=torch.tensor([len(a) for a in audio_list])
    token_lens=torch.tensor([len(t) for t in token_list])

    # Right-pad to the longest sample; a position is padding iff its index >= the sample length.
    padded_audio=torch.nn.utils.rnn.pad_sequence(audio_list,batch_first=True,padding_value=0.0)
    audio_pad_mask=torch.arange(padded_audio.size(1)).unsqueeze(0)>=audio_lens.unsqueeze(1)
    padded_tokens=torch.nn.utils.rnn.pad_sequence(token_list,batch_first=True,padding_value=pad_id)
    token_pad_mask=torch.arange(padded_tokens.size(1)).unsqueeze(0)>=token_lens.unsqueeze(1)

    # Shift by one: the decoder sees tokens[:-1] and must predict tokens[1:].
    return padded_audio,\
        audio_pad_mask,\
        padded_tokens[:,:-1].contiguous(),\
        token_pad_mask[:,:-1].contiguous(),\
        padded_tokens[:,1:].contiguous()
        
# Pull one shuffled batch through collate_fn to sanity-check the padded shapes.
train_ds=LRS2Dataset(split='train')
dataloader=torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
(batch_audio_features,
 batch_audio_pad_mask,
 batch_token_ids,
 batch_token_pad_mask,
 batch_next_token_ids)=next(iter(dataloader))
# Print every shape.
print(f'''
batch_audio_features.shape: {batch_audio_features.shape} 
batch_audio_pad_mask.shape: {batch_audio_pad_mask.shape} 
batch_token_ids.shape: {batch_token_ids.shape} 
batch_token_pad_mask.shape: {batch_token_pad_mask.shape} 
batch_next_token_ids.shape: {batch_next_token_ids.shape}
''')


batch_audio_features.shape: torch.Size([48, 548, 80]) 
batch_audio_pad_mask.shape: torch.Size([48, 548]) 
batch_token_ids.shape: torch.Size([48, 37]) 
batch_token_pad_mask.shape: torch.Size([48, 37]) 
batch_next_token_ids.shape: torch.Size([48, 37])



In [10]:
import time
import os 
from torch.optim.lr_scheduler import LinearLR
import swanlab #pip install swanlab
#swanlab.login(api_key='your-api-key', save=True)

def _load_checkpoint(model,optimizer,scheduler,dataset_size):
    """Restore training state from CHECKPOINT if it exists; return (start_epoch, samples_seen)."""
    if not os.path.exists(CHECKPOINT):
        return 0,0
    # map_location lets a checkpoint written on a GPU be resumed on a CPU-only machine.
    checkpoint=torch.load(CHECKPOINT,map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints only stored model + optimizer; they resume from epoch 0 as before.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch=checkpoint.get('epoch',0)
    samples=checkpoint.get('samples',start_epoch*dataset_size)
    return start_epoch,samples

def _save_checkpoint(model,optimizer,scheduler,next_epoch,samples):
    """Write model/optimizer/scheduler state, the next epoch index and the sample count to CHECKPOINT."""
    state={
        'model_state_dict':model.state_dict(),
        'optimizer_state_dict':optimizer.state_dict(),
        'scheduler_state_dict':scheduler.state_dict(),
        'epoch':next_epoch,
        'samples':samples,
    }
    tmp_path=CHECKPOINT+'.tmp'
    torch.save(state,tmp_path)
    # Save to a temp file first, so an interrupted torch.save never truncates the previous checkpoint.
    os.replace(tmp_path,CHECKPOINT)

def train(use_swanlab=False):
    """Train TransformerASR on the LRS2 train split, resuming from CHECKPOINT when present.

    The full training state (including the LR scheduler and epoch index) is checkpointed after
    every epoch. An interrupted run therefore continues its LR schedule and epoch count
    instead of restarting them.

    Args:
        use_swanlab: log metrics to SwanLab instead of printing them.
    """
    train_ds=LRS2Dataset(split='train')
    dataloader=torch.utils.data.DataLoader(dataset=train_ds,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_fn)
    model=TransformerASR(vocab_size=tokenizer.get_vocab_size()).to(device)
    optimizer=torch.optim.Adam(model.parameters(),lr=LR)
    # LR decays linearly from LR to 0.1*LR over EPOCHS scheduler steps (one step per epoch).
    scheduler=LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.1,
        total_iters=EPOCHS
    )
    # Padded target positions hold the [PAD] id and are excluded from the loss.
    loss_fn=torch.nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('[PAD]'))
    start_epoch,samples=_load_checkpoint(model,optimizer,scheduler,len(train_ds))
    if use_swanlab:
        swanlab.init(
            project='transformer-acr',
            experiment_name=f'experiment-{time.strftime("%Y%m%d-%H%M%S")}',
            description='Transformer ASR Training With LRS2 Dataset',
            config={
                'lr':LR,
                'architecture':str(model),
                'dataset':'LRS2',
                'epochs':EPOCHS,
                'batch_size':BATCH_SIZE,
            }
        )
    model.train()
    try:
        for epoch in range(start_epoch,EPOCHS):
            for batch in dataloader:
                (batch_audio_features,batch_audio_pad_mask,batch_token_ids,
                 batch_token_pad_mask,batch_next_token_ids)=(t.to(device) for t in batch)
                # The model returns unnormalised logits; CrossEntropyLoss applies log-softmax itself.
                logits=model(batch_audio_features,batch_audio_pad_mask,batch_token_ids,batch_token_pad_mask)
                loss=loss_fn(logits.reshape(-1,logits.size(-1)),batch_next_token_ids.reshape(-1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                samples+=len(batch_audio_features)
                lr=optimizer.param_groups[0]['lr']
                if use_swanlab:
                    swanlab.log({'loss':loss.item(),'epoch':samples/len(train_ds),'samples':samples,'lr':lr})
                else:
                    print(f'loss: {loss.item()}, epoch: {samples/len(train_ds)}, samples: {samples}, lr: {lr}')
            # Step before saving so the checkpoint already carries the LR for the next epoch.
            scheduler.step()
            _save_checkpoint(model,optimizer,scheduler,epoch+1,samples)
    finally:
        if use_swanlab:
            swanlab.finish()

train(use_swanlab=True)



[1m[34mswanlab[0m[0m: swanlab version 0.6.4 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1mc:\Users\owen\Documents\VsCode\transformer-acr\swanlog\run-20250622_155011-ab81f83f[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mowenliang[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mexperiment-20250622-155010[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@owenliang/transformer-acr[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@owenliang/transformer-acr/runs/v82evkjj682vc5uju33jo[0m[0m


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@owenliang/transformer-acr[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@owenliang/transformer-acr/runs/v82evkjj682vc5uju33jo[0m[0m
                                                                                                    


推理

In [11]:
from process_data import decode
def asr(model,tokenizer,audio_features,max_len=None):
    """Greedily decode a transcript for a single utterance.

    Args:
        model: a TransformerASR; it is switched to eval mode as a side effect.
        tokenizer: tokenizer that provides the '[BOS]' and '[EOS]' ids.
        audio_features: features of one utterance without a batch dimension
            (presumably (num_frames, AUDIO_FEATURE_SIZE)).
        max_len: maximum decoder length including [BOS]. It defaults to the size of the model's
            positional-encoding table, the longest sequence the model can embed. Decoding stops
            there if [EOS] was never produced.
    Returns:
        process_data.decode applied to the generated ids ([BOS] first, [EOS] excluded).
    """
    model.eval()
    # Follow the model's own device instead of assuming it lives on the global `device`.
    model_device=next(model.parameters()).device
    bos_id=tokenizer.token_to_id('[BOS]')
    eos_id=tokenizer.token_to_id('[EOS]')
    if max_len is None:
        max_len=model.pos_emb.pos_encoding.size(0)
    token_ids_list=[bos_id]
    # No autograd graph is needed for decoding; skip recording one at every step.
    with torch.inference_mode():
        audio_features=audio_features.unsqueeze(0).to(model_device)
        audio_pad_mask=torch.zeros(audio_features.shape[:2],dtype=torch.bool,device=model_device)
        enc_out=model.encode(audio_features,audio_pad_mask)
        while len(token_ids_list)<max_len:
            token_ids=torch.tensor([token_ids_list],dtype=torch.long,device=model_device)
            token_pad_mask=torch.zeros_like(token_ids,dtype=torch.bool)
            dec_out=model.decode(enc_out,audio_pad_mask,token_ids,token_pad_mask)
            # .item() keeps the id list as plain Python ints.
            next_token_id=dec_out[0,-1].argmax().item()
            if next_token_id==eos_id:
                break
            token_ids_list.append(next_token_id)
    return decode(tokenizer,token_ids_list)

In [28]:
from process_data import load_metadata,load_sample,load_tokenizer

# Transcribe one held-out utterance with the trained checkpoint.
SAMPLE_IDX=10
test_metas=load_metadata('data/test.txt')
sample=load_sample(test_metas[SAMPLE_IDX])
print(f'Path:{test_metas[SAMPLE_IDX]}')

tokenizer=load_tokenizer()
model=TransformerASR(vocab_size=tokenizer.get_vocab_size()).to(device)
# map_location: the checkpoint may have been written on a GPU; load it onto the current device.
checkpoint=torch.load(CHECKPOINT,map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

text=asr(model,tokenizer,sample['audio_features'])
print(text)

Path:6331559613336179781/00038
AND IF YOU'RE ONE WITH FALL
