In [2]:
!pip install f5_tts

Defaulting to user installation because normal site-packages is not writeable
Collecting f5_tts
  Downloading f5_tts-1.1.7-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: f5_tts
Successfully installed f5_tts-1.1.7


In [17]:
!pip install peft

Defaulting to user installation because normal site-packages is not writeable
Collecting peft
  Downloading peft-0.16.0-py3-none-any.whl (472 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.3/472.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: peft
Successfully installed peft-0.16.0


In [1]:
import torch
import argparse
from safetensors.torch import load_file
from f5_tts.model.utils import get_tokenizer
import warnings; warnings.filterwarnings("ignore")

In [2]:
def inspect_lora_weights(sd):
    # find all LoRA params in the after-finetuned checkpoint
    lora_items = {k: v for k,v in sd.items() if ".lora_" in k and isinstance(v, torch.Tensor)}
    if not lora_items:
        print("No LoRA weights found in this checkpoint.")
        return

    print("LoRA adapter parameter stats:\n")
    total_norm2 = 0.0
    total_elements = 0
    for name, tensor in sorted(lora_items.items()):
        t = tensor.cpu()
        mean_abs = t.abs().mean().item()
        max_abs  = t.abs().max().item()
        norm2    = t.norm().item()
        total_norm2 += norm2**2
        total_elements += t.numel()
        print(f"{name:60s}  shape={tuple(t.shape)}  mean|·|={mean_abs:.3e} max|·|={max_abs:.3e} ‖·‖₂={norm2:.3e}")

    overall_fro = total_norm2**0.5
    avg_norm    = overall_fro / (total_elements**0.5)
    print("\nSummary:")
    print(f"  Total LoRA parameters : {total_elements:,}")
    print(f"  Combined Frobenius ‖Δ‖₂ : {overall_fro:.3e}")
    print(f"  Avg per-param scale    : {avg_norm:.3e}")
    print()

In [2]:
# Path to your checkpoints
before_ckpt = "/work/users/r/p/rphadke/JSALT/ckpts/fisher_chunks_1K_LoRAv3.4/model_1000.pt"
after_ckpt = "/work/users/r/p/rphadke/JSALT/ckpts/fisher_chunks_1K_LoRAv3.4/model_9000.pt"
tok_path = "/work/users/r/p/rphadke/JSALT/fisher_chunks_0.1K_v2.1/vocab.txt"
vocab_char_map, vocab_size = get_tokenizer(tok_path, "custom")

# ID of your new speaker token (e.g. tokenizer.convert_tokens_to_ids("<spk>"))
spk_chg_token = "-"

    # 1) Load state dicts
sd_before = torch.load(before_ckpt,  map_location="cpu")["model_state_dict"]
print("\n == Prior Checkpoint ==")
# print(set(sd_before.keys()))
sd_after  = torch.load(after_ckpt,  map_location="cpu")["model_state_dict"]
sd_after = {k.replace("ema_model.", ""): v for k, v in sd_after.items()}
print("\n == New Checkpoint ==")
# print(set(sd_after.keys()))


 == Prior Checkpoint ==

 == New Checkpoint ==


In [13]:
from f5_tac.model.utils import list_str_to_idx

# 2) Extract embedding weights
#    adjust the key if your checkpoint nests it differently
Wb = sd_before["base_model.model.transformer.text_embed.text_embed.weight"]
Wa = sd_after["base_model.model.transformer.text_embed.text_embed.weight"]

ib = sd_before["base_model.model.transformer.input_embed.proj.weight"]
ia = sd_after["base_model.model.transformer.input_embed.proj.weight"]

# 3) Compute absolute difference
delta = (Wa - Wb).abs()

idelta = (ib - ia).abs()

spk_chg_token = "<utt>"

# 4) Metrics
speaker_chg_token_idx = list_str_to_idx([spk_chg_token], vocab_char_map=vocab_char_map)

max_overall   = delta.max().item()
max_speaker   = delta[speaker_chg_token_idx].max().item()

print("\n== Text Embedding Drift ==")
print(f"→ Max text embedding matrix change |Δ| overall:                  {max_overall:.6f}")
print(f"→ Max |Δ| on token {speaker_chg_token_idx}: {max_speaker:.6f}")

print("\n== Input Embedding Drift ==")
print(f"→ Mean input embedding weight change |Δ| overall:                  {idelta.max().item():.6f}")

print("== LoRA Weight Scales ==")
inspect_lora_weights(sd_after)


== Text Embedding Drift ==
→ Max text embedding matrix change |Δ| overall:                  0.000003
→ Max |Δ| on token tensor([[2545]]): 0.000000

== Input Embedding Drift ==
→ Mean input embedding weight change |Δ| overall:                  0.000051
== LoRA Weight Scales ==
LoRA adapter parameter stats:

base_model.model.transformer.text_embed.text_blocks.0.pwconv1.lora_A.default.weight  shape=(64, 512)  mean|·|=2.207e-02 max|·|=4.419e-02 ‖·‖₂=4.615e+00
base_model.model.transformer.text_embed.text_blocks.0.pwconv1.lora_B.default.weight  shape=(1024, 64)  mean|·|=4.530e-07 max|·|=4.475e-06 ‖·‖₂=1.568e-04
base_model.model.transformer.text_embed.text_blocks.0.pwconv2.lora_A.default.weight  shape=(64, 1024)  mean|·|=1.562e-02 max|·|=3.125e-02 ‖·‖₂=4.617e+00
base_model.model.transformer.text_embed.text_blocks.0.pwconv2.lora_B.default.weight  shape=(512, 64)  mean|·|=6.030e-07 max|·|=5.283e-06 ‖·‖₂=1.454e-04
base_model.model.transformer.text_embed.text_blocks.1.pwconv1.lora_A.default.weig

# Check Text Embedding Addition

In [3]:
import argparse
import os
import shutil
from importlib.resources import files
from functools import partial

import torch
from cached_path import cached_path

import yaml

# --- MODIFICATION: Import your new modules ---
from f5_tac.model.cfm import CFMWithTAC
from f5_tac.model.reccfm import CFMWithTACRecon
from f5_tac.model.backbones.dittac import DiTWithTAC
from f5_tac.model.trainer import Trainer
from f5_tac.configs.model_kwargs import lora_configv2
from f5_tts.model.utils import get_tokenizer

In [23]:
pretrain = "/work/users/r/p/rphadke/JSALT/ckpts/pretrained_model_1250000.safetensors"
local_pretrain_path = pretrain
print(f"Using pretrained model from: {local_pretrain_path}")

# --- 2. Setup Tokenizer ---
tokenizer_path = "/work/users/r/p/rphadke/JSALT/fisher_chunks_0.1K_v2.1/vocab.txt"
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path,"custom")
print(f"Loaded tokenizer with vocab size: {vocab_size}")

Using pretrained model from: /work/users/r/p/rphadke/JSALT/ckpts/pretrained_model_1250000.safetensors
Loaded tokenizer with vocab size: 2546


In [9]:
# --- 3. Define Model Architecture and Mel Spectrogram settings ---
# These should match the architecture of the pretrained model you are loading.
# mel_spec_kwargs = dict(
#     n_fft=1024, hop_length=256, win_length=1024,
#     n_mel_channels=100, target_sample_rate=24000, mel_spec_type="vocos",
# )


# dit_cfg = dict(
#     dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
# )

from f5_tac.configs.model_kwargs import mel_spec_kwargs, dit_cfg

# --- 4. Instantiate Your New Models ---
print("Instantiating F5-TAC models...")
transformer_backbone = DiTWithTAC(
    **dit_cfg,
    num_speakers=2, # Critical for TAC blocks
    text_num_embeds=vocab_size,
    mel_dim=mel_spec_kwargs["n_mel_channels"]
)

model = CFMWithTACRecon(
    transformer=transformer_backbone,
    mel_spec_kwargs=mel_spec_kwargs,
    vocab_char_map=vocab_char_map,

)

Instantiating F5-TAC models...


In [10]:
print("Loading pretrained weights...")
# 1) Load the raw checkpoint
if local_pretrain_path.endswith(".safetensors"):
    from safetensors.torch import load_file
    ckpt = load_file(local_pretrain_path, device="cpu")
else:
    ckpt = torch.load(local_pretrain_path, map_location="cpu")

# print(ckpt)

# 2) Unwrap any nesting
state = (
    ckpt.get("model_state_dict", 
    ckpt.get("ema_model_state_dict", ckpt))
)

# 3) Strip an ‘ema_model.’ prefix if it sneaked in
state = {k.replace("ema_model.", ""): v for k, v in state.items()}

old_embed_weight = state["transformer.text_embed.text_embed.weight"]
old_embed_weight

Loading pretrained weights...


tensor([[ 0.4294,  1.1332, -0.3012,  ...,  0.7869, -0.0200, -0.0887],
        [ 0.3353,  0.2027,  1.1162,  ...,  0.2878,  0.3035,  0.0241],
        [-0.1944, -0.5891, -0.3902,  ..., -0.8912, -0.2729,  1.2370],
        ...,
        [-0.9431,  0.1731, -0.1253,  ...,  0.5341, -1.6237, -0.4977],
        [ 0.3542,  0.4017,  0.2221,  ..., -0.7915,  0.5031,  0.7767],
        [ 0.4798, -0.1583,  0.2985,  ..., -0.0142, -1.1777, -0.1104]])

In [11]:
old_embed_weight.shape

torch.Size([2546, 512])

In [13]:
with torch.no_grad():
    new_embed_weights = torch.cat([
        old_embed_weight,                               # copy existing rows
        torch.randn(1, old_embed_weight.shape[1]) * 0.01  # random init new row
    ], dim=0)
new_embed_weights

tensor([[ 4.2939e-01,  1.1332e+00, -3.0124e-01,  ...,  7.8693e-01,
         -1.9993e-02, -8.8680e-02],
        [ 3.3533e-01,  2.0267e-01,  1.1162e+00,  ...,  2.8777e-01,
          3.0348e-01,  2.4130e-02],
        [-1.9440e-01, -5.8913e-01, -3.9018e-01,  ..., -8.9120e-01,
         -2.7291e-01,  1.2370e+00],
        ...,
        [ 3.5422e-01,  4.0167e-01,  2.2213e-01,  ..., -7.9148e-01,
          5.0305e-01,  7.7668e-01],
        [ 4.7975e-01, -1.5830e-01,  2.9851e-01,  ..., -1.4222e-02,
         -1.1777e+00, -1.1040e-01],
        [ 2.6985e-04, -1.6784e-02,  2.9970e-03,  ...,  1.5645e-03,
          6.0575e-03,  1.3508e-02]])

In [14]:
new_embed_weights.shape

torch.Size([2547, 512])

In [15]:
# Replace the weight in the state dict
state["transformer.text_embed.text_embed.weight"] = new_embed_weights

In [16]:
# 5) Finally load with strict=False to pick up whatever lines up
incompatible = model.load_state_dict(state, strict=False)


print("✔ loaded partial state:")
print("  • missing   (should only be tac module keys)   :", incompatible.missing_keys[:5], "…")
print("  • unexpected  :", incompatible.unexpected_keys[:5], "…")

✔ loaded partial state:
  • missing   (should only be tac module keys)   : ['transformer.transformer_blocks.0.tac.alpha', 'transformer.transformer_blocks.0.tac.transform_shared.0.weight', 'transformer.transformer_blocks.0.tac.transform_shared.0.bias', 'transformer.transformer_blocks.0.tac.transform_avg.0.weight', 'transformer.transformer_blocks.0.tac.transform_avg.0.bias'] …
  • unexpected  : ['initted', 'step'] …


In [18]:
from peft import LoraConfig, PeftModel, LoraModel, get_peft_model
model = get_peft_model(model, lora_configv2)
model.print_trainable_parameters()


for name, param in model.named_parameters():
    if "tac" in name or "text_embed.text_embed" in name or "3.dwconv" in name:
        param.requires_grad = True
        
    if "text_embed.text_embed" in name:
        print(name)

The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.


trainable params: 20,971,520 || all params: 773,507,706 || trainable%: 2.7112
base_model.model.transformer.text_embed.text_embed.weight


In [52]:
print(model)

PeftModel(
  (base_model): LoraModel(
    (model): PeftModel(
      (base_model): LoraModel(
        (model): CFMWithTACRecon(
          (mel_spec): MelSpec()
          (transformer): DiTWithTAC(
            (time_embed): TimestepEmbedding(
              (time_embed): SinusPositionEmbedding()
              (time_mlp): Sequential(
                (0): Linear(in_features=256, out_features=1024, bias=True)
                (1): SiLU()
                (2): Linear(in_features=1024, out_features=1024, bias=True)
              )
            )
            (text_embed): TextEmbedding(
              (text_embed): Embedding(2547, 512)
              (text_blocks): Sequential(
                (0): ConvNeXtV2Block(
                  (dwconv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), padding=(3,), groups=512)
                  (norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
                  (pwconv1): lora.Linear(
                    (base_layer): Linear(in_features=512, out_feat

# Text to Index New Function Test

In [6]:
for key in vocab_char_map.keys():
    if "utt" in key:
        print(key)

<utt>


In [7]:
from torch.nn.utils.rnn import pad_sequence
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
):
    list_idx_tensors = []
    for t in text:
        idxs = []
        i = 0
        while i < len(t):
            if t[i : i + 5] == "<utt>":  # Detect "<utt>"
                idxs.append(vocab_char_map.get("<utt>", 0))
                i += 5  # Skip over "<utt>"
            elif t[i : i + 5] == "<sil>":  # Detect "<utt>"
                idxs.append(-1)
                i += 5  # Skip over "<utt>"
            else:
                idxs.append(vocab_char_map.get(t[i], 0))
                i += 1
        list_idx_tensors.append(torch.tensor(idxs))
    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return text

text = "and i generally<sil><sil> prefer <utt>  eating at home <utt>  hello andy <utt>  how are you <utt>  good <utt>  do you have any idea what's going on <utt> "
list_str_to_idx([text], vocab_char_map)

tensor([[  62,  765,  250,    0,  507,    0,  369,  325,  765,  325,  940,   62,
          615,  615, 1236,   -1,   -1,    0,  834,  940,  325,  337,  325,  940,
            0, 2545,    0,    0,  325,   62, 1082,  507,  765,  369,    0,   62,
         1082,    0,  441,  827,  704,  325,    0, 2545,    0,    0,  441,  325,
          615,  615,  827,    0,   62,  765,  250, 1236,    0, 2545,    0,    0,
          441,  827, 1149,    0,   62,  940,  325,    0, 1236,  827, 1147,    0,
         2545,    0,    0,  369,  827,  827,  250,    0, 2545,    0,    0,  250,
          827,    0, 1236,  827, 1147,    0,  441,   62, 1148,  325,    0,   62,
          765, 1236,    0,  507,  250,  325,   62,    0, 1149,  441,   62, 1082,
            7,  973,    0,  369,  827,  507,  765,  369,    0,  827,  765,    0,
         2545,    0]])

In [25]:
from f5_tac.model.utils import list_str_to_idx

# from f5_tts.model.utils import list_str_to_idx

In [24]:
from f5_tac.model.dataset import load_conversation_dataset, conversation_collate_fn
from torch.utils.data import DataLoader

dataset_path = "/work/users/r/p/rphadke/JSALT/fisher_chunks_0.1K_v2.1"

print("Loading dataset...")
train_dataset = load_conversation_dataset(
    dataset_path=dataset_path,
    mel_spec_kwargs=mel_spec_kwargs
)

train_dataloader = DataLoader(
    train_dataset,
    collate_fn=conversation_collate_fn,
    num_workers=1,
    pin_memory=True,
    persistent_workers=True,
    batch_size=1,
    shuffle=True
)

for batch in train_dataloader:
    print(batch["text_A"])
    print(list_str_to_idx(batch["text_A"], vocab_char_map))
    break

Loading dataset...
Loading conversation dataset...
["e- e- everything could (( )) to something i think even gossiping <utt> they're (( enjoyable )) <utt> yeah but anything excessive i think ah <utt> e- it's something <utt> that <utt> challenges us <utt> at cer- certain points to rethink <utt> what <utt>"]
tensor([[ 325,   13,    0,  325,   13,    0,  325, 1148,  325,  940, 1236, 1082,
          441,  507,  765,  369,    0,  143,  827, 1147,  615,  250,    0,    8,
            8,    0,    9,    9,    0, 1082,  827,    0,  973,  827,  704,  325,
         1082,  441,  507,  765,  369,    0,  507,    0, 1082,  441,  507,  765,
          560,    0,  325, 1148,  325,  765,    0,  369,  827,  973,  973,  507,
          834,  507,  765,  369,    0, 2545,    0, 1082,  441,  325, 1236,    7,
          940,  325,    0,    8,    8,    0,  325,  765,  508,  827, 1236,   62,
           78,  615,  325,    0,    9,    9,    0, 2545,    0, 1236,  325,   62,
          441,    0,   78, 1147, 1082,    0, 

In [26]:
from f5_tac.model.dataset import load_conversation_dataset, conversation_collate_fn
from torch.utils.data import DataLoader

dataset_path = "/work/users/r/p/rphadke/JSALT/fisher_chunks_0.1K_v2.1"

print("Loading dataset...")
train_dataset = load_conversation_dataset(
    dataset_path=dataset_path,
    mel_spec_kwargs=mel_spec_kwargs
)

train_dataloader = DataLoader(
    train_dataset,
    collate_fn=conversation_collate_fn,
    num_workers=1,
    pin_memory=True,
    persistent_workers=True,
    batch_size=1,
    shuffle=True
)

for batch in train_dataloader:
    print(batch["text_A"])
    print(list_str_to_idx(batch["text_A"], vocab_char_map))
    break

Loading dataset...
Loading conversation dataset...
["yeah i think for me it would be thanksgiving um i'm jewish so i don't have the christmas thing to compare with <utt> and um <utt> [mn] so it's really i i do that and i do um passover with my family but that is more <utt> (( uh )) <utt> can get boring sometimes [laughter] and thank- i just really like to uh <utt> [mn] get together with my family so i like that <utt>"]
tensor([[1236,  325,   62,  441,    0,  507,    0, 1082,  441,  507,  765,  560,
            0,  337,  827,  940,    0,  704,  325,    0,  507, 1082,    0, 1149,
          827, 1147,  615,  250,    0,   78,  325,    0, 1082,  441,   62,  765,
          560,  973,  369,  507, 1148,  507,  765,  369,    0, 1147,  704,    0,
          507,    7,  704,    0,  508,  325, 1149,  507,  973,  441,    0,  973,
          827,    0,  507,    0,  250,  827,  765,    7, 1082,    0,  441,   62,
         1148,  325,    0, 1082,  441,  325,    0,  143,  441,  940,  507,  973,
         1