diff --git a/examples/models/moshi/mimi/test_mimi.py b/examples/models/moshi/mimi/test_mimi.py index 881f8c9371c..c8d9921ed67 100644 --- a/examples/models/moshi/mimi/test_mimi.py +++ b/examples/models/moshi/mimi/test_mimi.py @@ -2,6 +2,7 @@ import os import random import unittest +import urllib import numpy as np import requests @@ -23,7 +24,10 @@ from torch.export import export, ExportedProgram from torch.utils._pytree import tree_flatten -os.environ["https_proxy"] = "http://fwdproxy:8080" +proxies = { + "http": "http://fwdproxy:8080", + "https": "http://fwdproxy:8080", +} def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float: @@ -38,7 +42,12 @@ def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float: def read_mp3_from_url(url): - response = requests.get(url) + try: + response = requests.get(url) + except: + # FB-only hack, need to use a forwarding proxy to get url + response = requests.get(url, proxies=proxies) + response.raise_for_status() # Ensure request is successful audio_stream = io.BytesIO(response.content) waveform, sample_rate = torchaudio.load(audio_stream, format="mp3") @@ -68,7 +77,13 @@ def seed_all(seed): seed_all(42424242) if mimi_weight is None: - mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME) + try: + mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME) + except: + mimi_weight = hf_hub_download( + hf_repo, loaders.MIMI_NAME, proxies=proxies + ) + cls.mimi = loaders.get_mimi(mimi_weight, device) cls.device = device cls.sample_pcm, cls.sample_sr = read_mp3_from_url(