ETA of training code publication #1
It will likely be July or August. It depends on how well my other project goes and how many GPU resources I get in my lab. The GPUs are very busy during the summer because students are working on their projects full time, and some other projects will take priority before I can work on cleaning and testing the code. If you are interested, you can start with the StyleTTS w/ PL-BERT code and try to implement it yourself. I believe the most important parts are the adversarial training and the style diffusion, so I will provide the code snippets here. The code is neither tested nor cleaned; it was copied from the Jupyter notebook I ran the experiment with.

Style diffusion (you will need `from audio_diffusion_pytorch.modules import *`):

```python
# TransformerBlock, FixedEmbedding, TimePositionalEmbedding, rand_bool, exists,
# Conv1d, Rearrange, reduce, nn, torch, Tensor and Optional all come from this import
from audio_diffusion_pytorch.modules import *


class Transformer1d(nn.Module):
    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        embedding_max_length: int = 512,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
                for i in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def run(self, x, time, embedding, features):
        mapping = self.get_mapping(time, features)
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            x = x + mapping
            x = block(x)

        x = x.mean(axis=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)

        return x

    def forward(self, x: Tensor,
                time: Tensor,
                embedding_mask_proba: float = 0.0,
                embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None,
                embedding_scale: float = 1.0) -> Tensor:
        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            # Compute both normal and fixed embedding outputs
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            # Scale conditional output using classifier-free guidance
            return out_masked + (out - out_masked) * embedding_scale
        else:
            return self.run(x, time, embedding=embedding, features=features)


transformer = Transformer1d(
    num_layers=3,
    channels=256,
    num_heads=8,
    head_features=64,
    multiplier=2,
    context_embedding_features=768,
)

from audio_diffusion_pytorch import AudioDiffusionConditional, DiffusionSampler
# KDiffusion, LogNormalDistribution, ADPM2Sampler and KarrasSchedule must also be
# importable here (assumed: from the same audio_diffusion_pytorch version as above)

diffusion = AudioDiffusionConditional(
    in_channels=1,
    embedding_max_length=512,
    embedding_features=768,
    embedding_mask_proba=0.1,  # Conditional dropout of batch elements
    multipliers=[1, 2],
    channels=256,
    patch_size=16,
    factors=[2],
    attentions=[0, 1],
    num_blocks=[2],
)

diffusion.diffusion.net = transformer
diffusion.unet = transformer

diffusion.diffusion = KDiffusion(
    net=diffusion.unet,
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.2,
    dynamic_threshold=0.0,
)

sampler = DiffusionSampler(
    model.diffusion.diffusion,  # `model` is the notebook's module container; model.diffusion is the `diffusion` above
    num_steps=5,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    clamp=False,
)
```

WavLM discriminator:

```python
import torch
import torchaudio
from transformers import WavLMModel


class WavLMLoss(torch.nn.Module):
    def __init__(self, model, wd):
        """Initialize the WavLM-based discriminator loss module."""
        super(WavLMLoss, self).__init__()
        self.wavlm = WavLMModel.from_pretrained(model)
        self.wd = wd  # discriminator applied to the stacked WavLM hidden states
        self.resample = torchaudio.transforms.Resample(24000, 16000)

    def wd_forward(self, wav, y_rec):
        # Real audio needs no gradient; the generated audio must keep its graph
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
        y_rec_16 = self.resample(y_rec)
        y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states

        # (B, layers, T, C) -> (B, layers * C, T)
        y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
        y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        y_d_rs = self.wd(y_embeddings)
        y_d_gs = self.wd(y_rec_embeddings)

        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

        loss_gen_f = torch.mean((1 - y_df_hat_g) ** 2)

        loss_rel = 0
        loss_gen_all = loss_gen_f + loss_rel

        return loss_gen_all

    def wd_discriminator(self, wav, y_rec):
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
            y_rec_16 = self.resample(y_rec)
            y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states

            y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
            y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        y_d_rs = self.wd(y_embeddings)
        y_d_gs = self.wd(y_rec_embeddings)

        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
        g_loss = torch.mean((y_df_hat_g) ** 2)

        loss_disc_f = r_loss + g_loss

        loss_rel = 0
        d_loss = loss_disc_f + loss_rel

        return d_loss.mean()


wl = WavLMLoss('microsoft/wavlm-base-plus', model.wd).to('cuda')
```

Adversarial training run:

```python
import numpy as np
import torch
import torch.nn.functional as F

# model, optimizer, sampler, wl, device, train_dataloader and length_to_mask
# come from the surrounding notebook
for i, batch in enumerate(train_dataloader):
    waves = batch[0]
    batch = [b.to(device) for b in batch[1:]]
    texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, labels = batch

    # ... joint training code omitted (s_trg, d_gt and y_rec_gt_pred used below come from it)

    use_ind = np.random.rand() < 0.5

    if use_ind:
        ref_lengths = input_lengths
        ref_texts = texts

    text_mask = length_to_mask(ref_lengths).to(texts.device)
    bert_dur = model.bert(ref_texts, attention_mask=(~text_mask).int()).last_hidden_state
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    if use_ind and np.random.rand() < 0.5:
        s_preds = s_trg
    else:
        num_steps = np.random.randint(3, 5)
        s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to('cuda'),
                          embedding=bert_dur,
                          embedding_scale=1,
                          embedding_mask_proba=0.1,
                          num_steps=num_steps).squeeze(1)

    s_dur = s_preds[:, 128:]
    s = s_preds[:, :128]

    d, _ = model.predictor(d_en, s_dur,
                           ref_lengths,
                           torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to('cuda'),
                           text_mask)

    output_lengths = []
    attn_preds = []

    # turn the predicted durations into a soft text-to-frame alignment
    for _s2s_pred, _text_input, _text_length in zip(d, d_gt, ref_lengths):
        _s2s_pred_org = _s2s_pred[:_text_length, :]

        _s2s_pred = torch.sigmoid(_s2s_pred_org)
        _dur_pred = _s2s_pred.sum(axis=-1)

        _text_input = _text_input[:_text_length].long()

        l = int(torch.round(_s2s_pred.sum()).item())
        t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to('cuda')
        loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2

        sig = 1.5
        h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (sig) ** 2)

        out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
                                         h.unsqueeze(1),
                                         padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
        attn_preds.append(F.softmax(out.squeeze(), dim=0))

        output_lengths.append(l)

    max_len = max(output_lengths)

    with torch.no_grad():
        t_en = model.text_encoder(ref_texts, ref_lengths, text_mask)

    s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to('cuda')
    for bib in range(len(output_lengths)):
        s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]

    asr_pred = t_en @ s2s_attn

    _, p_pred = model.predictor(d_en, s_dur,
                                ref_lengths,
                                s2s_attn,
                                text_mask)

    mel_len = max(int(min(output_lengths) / 2 - 1), 200)
    mel_len = min(mel_len, 250)

    en = []
    p_en = []
    sp = []
    l = []

    F0_fakes = []
    N_fakes = []

    for bib in range(len(output_lengths)):
        mel_length_pred = output_lengths[bib]
        mel_length_gt = int(mel_input_length[bib].item() / 2)
        if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
            continue

        sp.append(s_preds[bib])

        random_start = np.random.randint(0, mel_length_pred - mel_len)
        en.append(asr_pred[bib, :, random_start:random_start+mel_len])
        p_en.append(p_pred[bib, :, random_start:random_start+mel_len])

        l.append(labels[bib])

    if len(sp) <= 1:
        continue

    sp = torch.stack(sp)
    en = torch.stack(en)
    p_en = torch.stack(p_en)
    labels = torch.stack(l)

    F0_fake, N_fake = model.predictor.F0Ntrain(p_en, sp[:, 128:])
    y_pred = model.decoder(en, F0_fake, N_fake, sp[:, :128])

    wav = y_rec_gt_pred

    optimizer.zero_grad()
    d_loss = wl.wd_discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
    d_loss.backward()
    optimizer.step('wd')

    # generator loss
    optimizer.zero_grad()
    loss_gen_lm = wl.wd_forward(wav.squeeze(), y_pred.squeeze())
    loss_gen_lm = loss_gen_lm.mean()
    loss_gen_lm.backward(retain_graph=True)

    total_norm = {}
    for key in model.keys():
        total_norm[key] = 0
        parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
        for p in parameters:
            param_norm = p.grad.detach().data.norm(2)
            total_norm[key] += param_norm.item() ** 2
        total_norm[key] = total_norm[key] ** 0.5

    if total_norm['predictor'] > 20:
        for key in model.keys():
            for p in model[key].parameters():
                if p.grad is not None:
                    p.grad *= 2 * (1 / total_norm['predictor'])

    for p in model.predictor.duration_proj.parameters():
        if p.grad is not None:
            p.grad *= 1e-2

    for p in model.predictor.lstm.parameters():
        if p.grad is not None:
            p.grad *= 1e-2

    for p in model.diffusion.parameters():
        if p.grad is not None:
            p.grad *= 1e-2

    optimizer.step('bert_encoder')
    optimizer.step('bert')
    optimizer.step('predictor')
    optimizer.step('diffusion')
    optimizer.step('style_encoder')
    optimizer.step('predictor_encoder')  # this is the prosodic style encoder, will rename it later in cleaned-up code
    optimizer.step('decoder')
```

I will leave this issue open in case someone else is interested in the implementation.
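The `embedding_scale` branch of `Transformer1d.forward` is classifier-free guidance: the network runs once with the BERT embedding and once with the content-independent `fixed_embedding`, and the two outputs are mixed. A minimal sketch of only that mixing step, with plain tensors standing in for the two network outputs (`cfg_mix` is a hypothetical name, not part of the notebook):

```python
import torch


def cfg_mix(out_cond: torch.Tensor, out_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    """Classifier-free guidance: move from the unconditional output toward
    the conditional one, and past it when scale > 1."""
    return out_uncond + (out_cond - out_uncond) * scale


out_cond = torch.tensor([1.0, 2.0])    # stands in for self.run(..., embedding=embedding)
out_uncond = torch.tensor([0.0, 1.0])  # stands in for self.run(..., embedding=fixed_embedding)

print(cfg_mix(out_cond, out_uncond, 1.0))  # scale 1 returns the conditional output
print(cfg_mix(out_cond, out_uncond, 2.0))
```

Because the formula reduces to the conditional output at scale 1, the snippet skips the second forward pass when `embedding_scale == 1.0`, which is also the value the training loop passes to the sampler.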
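`wd_discriminator` and `wd_forward` apply least-squares GAN (LSGAN) objectives to the discriminator's scores on the stacked WavLM hidden states: the discriminator pushes scores on real audio toward 1 and on generated audio toward 0, while the generator pushes scores on generated audio toward 1. The sketch below isolates just those two formulas on dummy score tensors; the function names are hypothetical and no WavLM model is involved.

```python
import torch


def lsgan_d_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Discriminator: real -> 1, generated -> 0 (same form as r_loss + g_loss in wd_discriminator)
    return torch.mean((1 - d_real) ** 2) + torch.mean(d_fake ** 2)


def lsgan_g_loss(d_fake: torch.Tensor) -> torch.Tensor:
    # Generator: push scores on generated audio toward 1 (same form as loss_gen_f in wd_forward)
    return torch.mean((1 - d_fake) ** 2)


d_real = torch.tensor([1.0, 0.5])
d_fake = torch.tensor([0.0, 0.5])
print(lsgan_d_loss(d_real, d_fake).item())
print(lsgan_g_loss(d_fake).item())
```

In the training loop, `wav` and `y_pred` are detached before `wd_discriminator`, so the discriminator step does not back-propagate into the decoder; only the later `wd_forward` call sends gradients to the generator side.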
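The loop over `zip(d, d_gt, ref_lengths)` converts per-token duration predictions into a soft text-to-frame alignment using Gaussian kernels centered at `cumsum(dur) - dur / 2` with `sig = 1.5`. As a simpler point of comparison, and explicitly not the notebook's conv1d formulation, here is plain Gaussian upsampling in the style of Non-Attentive Tacotron with a fixed sigma: each frame gets a softmax over tokens of the Gaussian log-density around each token center. `gaussian_upsample` is a hypothetical helper.

```python
import torch


def gaussian_upsample(durations: torch.Tensor, sigma: float = 1.5) -> torch.Tensor:
    """Return a (num_tokens, num_frames) soft alignment from per-token durations.

    Token i is centered at cumsum(durations)[i] - durations[i] / 2 (the same
    centers as `loc` in the notebook); each frame's column is a softmax over tokens.
    """
    num_frames = int(torch.round(durations.sum()).item())
    centers = torch.cumsum(durations, dim=0) - durations / 2          # (num_tokens,)
    frames = torch.arange(num_frames, dtype=durations.dtype) + 0.5     # frame midpoints
    logits = -0.5 * ((frames[None, :] - centers[:, None]) / sigma) ** 2
    return torch.softmax(logits, dim=0)                               # normalize over tokens


attn = gaussian_upsample(torch.tensor([2.0, 3.0, 1.0]))
print(attn.shape)  # torch.Size([3, 6])
```

Normalizing over the token axis mirrors the `F.softmax(out.squeeze(), dim=0)` call in the snippet, and every step is differentiable with respect to `durations`, which is what lets the adversarial loss reach the duration predictor.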
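After `loss_gen_lm.backward(...)`, the loop computes a per-module gradient L2 norm and, if the predictor's norm exceeds 20, rescales every module's gradients by `2 / total_norm['predictor']`, before damping the duration projection, the predictor LSTM and the diffusion model by `1e-2`. A self-contained sketch of the norm-and-rescale part over a plain dict of `nn.Module`s; the helper names and toy modules are hypothetical, while the threshold 20 and factor 2 are taken from the snippet.

```python
import torch
import torch.nn as nn


def grad_norms(modules: dict) -> dict:
    """Per-module L2 norm over all parameters that currently have gradients."""
    norms = {}
    for key, module in modules.items():
        sq = sum(p.grad.detach().norm(2).item() ** 2
                 for p in module.parameters() if p.grad is not None and p.requires_grad)
        norms[key] = sq ** 0.5
    return norms


def rescale_if_large(modules: dict, ref_key: str = "predictor",
                     thresh: float = 20.0, factor: float = 2.0) -> dict:
    """If the reference module's norm exceeds thresh, scale all grads by factor / norm."""
    norms = grad_norms(modules)
    if norms[ref_key] > thresh:
        scale = factor / norms[ref_key]
        for module in modules.values():
            for p in module.parameters():
                if p.grad is not None:
                    p.grad *= scale
    return norms


torch.manual_seed(0)
modules = {"predictor": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)}
x = torch.randn(8, 4)
loss = 1000 * (modules["predictor"](x).sum() + modules["decoder"](x).sum())
loss.backward()

before = rescale_if_large(modules)
after = grad_norms(modules)
print(before["predictor"], after["predictor"])  # after is ~2.0 when before > 20
```

Every module is scaled by the same factor derived from the predictor's norm, so relative gradient sizes across modules are preserved; only the overall step is shrunk when the predictor's gradients blow up.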
Thank you for such a quick response!
Hello, thank you for sharing the code. Could you also share the duration predictor part?
@talipturkmen It is a very simple change:
Any updates?
@spolezhaev Sorry for keeping you waiting. I'm now halfway done with my other projects and have started working on code cleaning. I will make sure the code is available by the end of this month.
Just wanted to say, amazing work. This is reaching almost tortoise-tts-level quality with the super-fast inference of StyleTTS. Can't wait to try this model out and maybe even fine-tune it!
Looking forward to it!
@nivibilla Not sure if you are referring to this one, but perceptually speaking it doesn't even sound better than VITS with language models (BERT), so do you mind explaining why people are so excited about it, especially given its insanely slow inference speed? Is it because it was trained on millions of hours of speech, so people like its zero-shot speaker adaptation ability?
No, it's because one can fine-tune it too, on any voice. If only it had low VRAM usage and some genuine way to speed up inference, I doubt there would be any need to look for another low-resource, human-level speech synthesis model.
@yl4579 The original author removed the finetuning code and nerfed it on purpose. However, there are other branches that people have made. And when I've tested finetuning personally, it captures the style and nuances of the voice really well. Out of the box it isn't that good, but finetuning on just an hour or so of data yields the most natural-sounding TTS.
@nivibilla @exllama-fan I think that makes sense when the base model is large enough and pre-trained on enormous datasets, but unfortunately I don't believe you can fine-tune StyleTTS 2 to get similar performance, especially with only one hour of data, because the biggest model I have so far was trained on only 585 hours of data (LibriTTS-R), which is not comparable to Tortoise TTS trained on millions of hours. I simply don't have that much data (I believe Tortoise TTS was trained by someone from OpenAI?). I do plan to train a model on Multilingual LibriSpeech with speech restoration like that used in LibriTTS-R, but our lab has other priorities for GPUs, so I'm not sure when I will have time for this.
@yl4579 I am planning to write a multilingual extension for this model myself. In addition to adding a language embedding, from what I can tell, both the WavLM model and the PL-BERT model will have to be replaced with multilingual versions? I have access to some GPU resources for a few weeks and am happy to train a model and share back the results.
For those who are waiting for the code, I apologize for the delay, but I'm having difficulties reproducing the results I got from the Jupyter notebooks. After a few weeks of code cleaning, I found substantial performance differences between models trained with the original notebooks and with the cleaned code. I'm still investigating the causes, so the code release will be delayed. If anyone is interested in reproducing the results with the notebooks I have and helping me clean the code, please email me at yl4579@columbia.edu and I'm happy to provide the Jupyter notebooks and the cleaned code I used for the experiments.
Just shot you a note from my Gmail, @yl4579. Happy to help clean + retrain.
@yl4579 +1, sent you an email, happy to help clean the code and reproduce the results.
+1, happy to help clean the code and reproduce the results.
@WendongGan I didn't get your email. Please email me at yl4579@columbia.edu.
@yl4579 I'm sorry! I just sent it, thank you for checking.
@yl4579 +1, just sent you an email. Happy to help!
Is there any progress with the code cleaning? Any further update on the code release?
@lovebeatz So far only one person has confirmed that the stage I (acoustic pretraining) code is probably fine, but nobody has reported any success in fixing the stage II training (joint training), which shows a discrepancy in the F0 and norm losses between the Jupyter notebook and the cleaned code. I probably won't have time to work on this until the end of this month, but hopefully someone can get it fixed soon.
This issue is inexplicably weird. I'm not sure whether I should release the problematic code to the public anyway and mark it as WIP, or just wait a little longer for the volunteers who have emailed me to work on it.
So the notebook code works well? Do you have the pre-trained model ready? Also, can you tell me about the inference code?
The inference code seems to work. At least I didn't notice any clear degradation in the quality of synthesized speech between the cleaned code and the notebook I used to run the experiment. I couldn't find the exact checkpoints I used for the experiments in the paper, but I do have trained checkpoints for reference. I can share them with you, though I'm not sure how useful they really are. You can email me for the inference code if you need it.
I am looking to fine-tune using the notebook and run inference using the cleaned code. Also, will a separate fine-tuned model be required for every new voice?
Unfortunately, the cleaned code still doesn't work at this time because it produces higher losses with worse quality. The notebook is uncleaned and uncommented, so I don't think you can use it to fine-tune anything at this point unless you know how to read the code and modify it for your own purpose. As for voices, I assume you mean speakers, so no: if you have multiple speakers you want to fine-tune with, you only need one model for as many speakers as you want. This is also the exact point of zero-shot speaker adaptation.
But I read somewhere, as you mentioned, that StyleTTS 2 won't require any speaker reference for inference, unlike StyleTTS. Also, given what was seen in StyleTTS, does this new version produce sentence breaks in speech? StyleTTS won't pause at a full stop if another sentence follows.
@lovebeatz StyleTTS 2 doesn't require any reference for single speaker models, but it still needs a reference from the target speaker for multispeaker models because it needs to know which speaker you are about to synthesize. If your goal is to train a single speaker model like on the LJSpeech dataset, you don't need a reference. As for the pauses, yes, StyleTTS 2 does have sentence breaks. I just synthesized your sentences above and confirmed it can (and surprisingly StyleTTS w/ PL-BERT can't even do that for some reason): StyleTTS w/ PL-BERT: https://drive.google.com/file/d/1llMmllk9QyGYXBqsbRVjKzPQxcY7XQvW/view?usp=sharing
@yl4579 +1, I sent you an email. Can I help?
I think if it doesn't affect inference or the quality of the model, it should be released, if you are looking for an opinion.
It's less efficient in terms of RAM (and speed, because you now have to reduce the batch size), and you can't easily do mixed precision this way either, but it's just a matter of engineering, so it can be handled by people more fluent in programming than me. So I'm thinking of releasing the DP version first if I can't fix the code by the end of the month, and seeing whether anyone else is interested in fixing the DDP version.
Can you share the DDP training code? I'm training with DDP, but it's super slow.
@primepake I will make a new repo with the broken DDP code if I can't fix it by the end of this month so that other people can work on it later.
Looking forward to it.
Unfortunately, whenever I turn DP into DDP, the F0 loss is consistently higher, which is very weird. I'll probably have to release the DP code first and see if any DDP expert can fix the code later.
Thanks for the effort, guys.
Hi @yl4579, thank you for your work! I would like to reproduce the results from the paper and can also try to help with the DDP issue. May I ask you to add me to the cleaning repo, please?
@danielmsu The code with DP is almost done and I'll probably push it in a couple of days. I will make another public repo with the non-working DDP code for those with the expertise to lend a hand.
@danielmsu Actually, the broken DDP code doesn't need a separate repo. I just opened a new issue #7 for this problem and copy-pasted the code there. The code can be tested under this repo directly.
@yl4579 Thank you, I will check it out.
@danielmsu I haven't encountered this problem. I generated speech from the same text, and it works totally fine for me. Have you run the entire demo, and does every single audio you generated sound like this? If so, it sounds like some dependencies might be messed up. Some small variation is totally normal because the model is stochastic in nature, but the quality difference shouldn't be this big.
@danielmsu I've created a Colab notebook that you can try here: https://colab.research.google.com/drive/1k5OqSp8a-x-27xlaWr2kZh_9F9aBh39K. I have tested it and it works totally fine.
Hey @yl4579 @danielmsu, I ran into the same issue with some high-pitched noise in the background. After checking dependencies, new installs, etc., it turns out it is related to the type of GPU you use. I had a fairly old GPU (Quadro P5000), but when I switched to a new machine it was fine. Also, on the old machine, if I use device='cpu' it works fine too. Not sure why this happens; maybe it's related to the way older GPUs perform some operations or handle float representations? However, I found another difference. For the example "Maltby and Co. would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses." in your audio samples (https://styletts2.github.io/), "Co." is pronounced "company", while in the notebook (and on my local machine) it is pronounced "co". Example from the notebook: https://vocaroo.com/1iCB2q0HOqLh
@teopapad92 You actually need this text cleaner (https://github.com/jaywalnut310/vits/blob/main/text/cleaners.py) to make "Co." pronounced as "company". I used it for the paper and the demo page samples but didn't include it in the inference notebook, though I may add it later to be consistent with the demo page audio. However, it is just a matter of phonemization, so people can do whatever they want. The training data was already phonemized this way, as I took it directly from VITS. As for the high-pitched distortion, can you make sure it is a GPU problem and not a dependency problem? Have you tested with the exact same environment, and does the distortion disappear after changing the GPU? I have tried it on GPUs as old as the NVIDIA 780 and it still works. I think it's a great idea to open a new issue for this problem.
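The "Co." to "company" behavior comes from abbreviation expansion in the linked VITS cleaner, which rewrites the text before it is phonemized. A minimal sketch of that idea, using a small illustrative subset of abbreviations rather than the full list in `cleaners.py` (the names here are hypothetical):

```python
import re

# (abbreviation, expansion) pairs: a small illustrative subset, matched
# case-insensitively only when followed by a period
_ABBREVIATIONS = [
    (re.compile(r"\b%s\." % abbr, re.IGNORECASE), expansion)
    for abbr, expansion in [
        ("mr", "mister"),
        ("dr", "doctor"),
        ("co", "company"),
        ("ltd", "limited"),
    ]
]


def expand_abbreviations(text: str) -> str:
    for pattern, expansion in _ABBREVIATIONS:
        text = pattern.sub(expansion, text)
    return text


print(expand_abbreviations("Maltby and Co. would issue warrants on them"))
# Maltby and company would issue warrants on them
```

Since the expansion happens on the raw text, the phonemizer then sees "company" instead of "Co.", which is why adding the cleaner to the inference notebook would match the demo page audio.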
@yl4579 Thank you for the answer and the Colab notebook. I switched to the CPU as suggested by @teopapad92 and can confirm that the issue is gone.
@danielmsu This is so weird, can you open a new issue so other people can have a reference?
Might be some library or framework that's different, or maybe how different precisions are being handled by the GPU?
On Wed, Oct 4, 2023 at 12:15 PM Aaron (Yinghao) Li wrote:
> @danielmsu This is so weird, can you open a new issue so other people can have a reference?
Tested one more time inside a Docker container (with
I have tested the current code and I was able to reproduce models with quality similar to those used in the paper and demo. So I think this issue is now resolved. Please open new issues if there are more problems with the current code.
@nivibilla I've pushed the finetuning script, and I tried it myself with one hour of LJSpeech data using the pre-trained LibriTTS model. It sounds better than TortoiseTTS for sure (in both quality and speaker similarity), but it is still worse than models trained from scratch with the full data (24 hours of audio). The quality is still better than VITS and JETS and close to NaturalSpeech, so I think it is good enough for one hour of data.
@yl4579 Thanks so much. Will try it out when I get the time.
@yl4579 Can you share the minimum required audio length for a new speaker and how long it takes to fine-tune?
@primepake It depends on the quality you want to achieve. The more data the better, but I tried finetuning with 10 minutes of audio and it still works; the similarity is much better, though I wouldn't say the naturalness is better.
@yl4579 Sorry for hijacking the conversation, but do I understand correctly that after fine-tuning we still need to provide reference audio, but the results are much more similar than with zero-shot voice cloning?
@danielmsu Yes, but if the dataset you are finetuning on has a single speaker, the reference can be arbitrary and doesn't affect the synthesized speech. You can also set the multispeaker flag to false and not load the diffusion model when finetuning if you know your new dataset has only one speaker.
Thank you for your work! Is there any ETA on when the training and inference code will become available?