Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions examples/models/moshi/mimi/test_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import unittest
import urllib

import numpy as np
import requests
Expand All @@ -23,7 +24,10 @@
from torch.export import export, ExportedProgram
from torch.utils._pytree import tree_flatten

os.environ["https_proxy"] = "http://fwdproxy:8080"
proxies = {
"http": "http://fwdproxy:8080",
"https": "http://fwdproxy:8080",
}


def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
Expand All @@ -38,7 +42,12 @@ def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:


def read_mp3_from_url(url):
response = requests.get(url)
try:
response = requests.get(url)
except:
# FB-only hack, need to use a forwarding proxy to get url
response = requests.get(url, proxies=proxies)

response.raise_for_status() # Ensure request is successful
audio_stream = io.BytesIO(response.content)
waveform, sample_rate = torchaudio.load(audio_stream, format="mp3")
Expand Down Expand Up @@ -68,7 +77,13 @@ def seed_all(seed):
seed_all(42424242)

if mimi_weight is None:
mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME)
try:
mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME)
except:
mimi_weight = hf_hub_download(
hf_repo, loaders.MIMI_NAME, proxies=proxies
)

cls.mimi = loaders.get_mimi(mimi_weight, device)
cls.device = device
cls.sample_pcm, cls.sample_sr = read_mp3_from_url(
Expand Down
Loading