In [36]:
import sys
sys.path.append("./src")
from models.codec import SwinAudioCodec

import yaml, torch
# All tensors and the model in this notebook live on the CPU.
device = "cpu"

# Parse the YAML configuration for the residual 18k codec.
with open("./src/configs/residual_18k.yml", 'r') as config_file:
    config = yaml.safe_load(config_file)

# The "model" section is expanded into the constructor's keyword arguments.
model_kwargs = config["model"]
model = SwinAudioCodec(**model_kwargs).to(device)

Audio Codec 18.0kbps Initialized
Codec Causality: False
Apply Residual-Based Cross Attention Fusion Net for Swin Codec | Type: None
Quantization Vis: 
     EMA: False CosineSimilarity: True
     Freq dims:  [2, 2, 4, 8, 16, 32]
     Channel(hidden) dims:  [384, 384, 192, 144, 96, 72]
     Merged dims:  [768, 768, 768, 1152, 1536, 2304]
     GroupVQ proj dims:  [768, 768, 768, 1152, 1536, 2304]
     GroupVQ dims (for each):  [256, 256, 256, 384, 512, 768]
     Mapped Codebook dims (for each):  [8, 8, 8, 8, 8, 8]
Pre-swin Layer: swin_depth=2 swin_hidden=45 heads=3 down=False
Layer[0]: swin_depth=2 swin_hidden=45 heads=3 down=True
Layer[1]: swin_depth=2 swin_hidden=72 heads=6 down=True
Layer[2]: swin_depth=2 swin_hidden=96 heads=12 down=True
Layer[3]: swin_depth=2 swin_hidden=144 heads=24 down=True
Layer[4]: swin_depth=2 swin_hidden=192 heads=24 down=True

Layer[0]: swin_depth=2 swin_hidden=384 heads=24 up=True
Layer[1]: swin_depth=2 swin_hidden=192 heads=24 up=True
Layer[2]: swin_depth=2

In [37]:
# Tally the element count of every parameter that has requires_grad set.
trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
trainable_params

8533829

## Evaluate Latency

In [38]:
import time, glob, torchaudio

# NOTE(review): absolute, machine-specific path — point this at your own evaluation folder.
eval_pth = "/Users/tracy/Desktop/eval_instances"

# glob returns matches in arbitrary order; sorting makes the evaluated subset
# the same on every run and on every machine.
eval_files = sorted(glob.glob(f"{eval_pth}/*.wav"))

# Only the first five clips are timed below, so decode just those instead of
# loading the whole folder and discarding the rest.
eval_audios = [
    (torchaudio.load(f)[0]).to(device) for f in eval_files[:5]
]

model = model.to(device)
def compress(model, n_streams, audios=None):
    """Encode a batch of clips with ``model`` and collect the resulting codes.

    Args:
        model: Object exposing ``encode(audio, num_streams=...)`` that returns
            a 2-tuple; only the first item is kept.
        n_streams: Forwarded to ``model.encode`` as ``num_streams``.
        audios: Optional iterable of inputs to encode. When omitted, the
            notebook-level ``eval_audios`` list is used (the original behaviour).

    Returns:
        list: One entry per input clip, in input order.

    Note:
        The model's train/eval mode is left untouched; callers that want
        inference-mode layers should call ``model.eval()`` beforehand.
    """
    if audios is None:
        audios = eval_audios

    encoded = []
    # Inference only: disable autograd so no graph is recorded while encoding.
    with torch.no_grad():
        for clip in audios:
            multi_codes, _ = model.encode(clip, num_streams=n_streams)
            encoded.append(multi_codes)

    return encoded


def recover(model, encoded, enc_feat_size=(2, 1000)):
    """Decode every entry of ``encoded`` with ``model``, discarding the results.

    The function exists to exercise (and time) the decoder; nothing is returned.

    Args:
        model: Object exposing ``decode(codes, enc_feat_size=...)``.
        encoded: Iterable of code objects, e.g. the list produced by ``compress``.
        enc_feat_size: Forwarded to ``model.decode``. Defaults to ``(2, 1000)``,
            the value previously hard-coded here.

    Returns:
        None
    """
    # Inference only: disable autograd so no graph is recorded while decoding.
    with torch.no_grad():
        for encoded_d in encoded:
            model.decode(encoded_d, enc_feat_size=enc_feat_size)
    return


In [39]:
# time.perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; time.time() is the wall clock and can jump if the system
# clock is adjusted.
start = time.perf_counter()
encoded = compress(model, n_streams=6)
end = time.perf_counter()
print(f"Ours Compress Time ({len(eval_audios)} 10sec audio) on {device}: ", end - start)

start = time.perf_counter()
recover(model, encoded)
end = time.perf_counter()
print(f"Ours Recover Time ({len(eval_audios)} 10sec audio) on {device}: ", end - start)

Ours Compress Time (5 10sec audio) on cpu:  6.510828971862793
Ours Recover Time (5 10sec audio) on cpu:  4.136021852493286


## Count VQ Codebook Utilization

In [None]:
import json

# Load the saved usage statistics; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open("./assets/results/swin-18k-residual/vq_stats/train/usage.json", "r") as usage_file:
    usage_stats = json.load(usage_file)

In [None]:
usage_stats

{'stream_0_group_1': 96.875,
 'stream_0_group_2': 57.71484375,
 'stream_0_group_3': 100.0,
 'stream_0_group_4': 86.23046875,
 'stream_0_group_5': 48.6328125,
 'stream_0_group_6': 100.0,
 'stream_1_group_1': 98.73046875,
 'stream_1_group_2': 43.84765625,
 'stream_1_group_3': 100.0,
 'stream_1_group_4': 96.19140625,
 'stream_1_group_5': 76.66015625,
 'stream_1_group_6': 100.0,
 'stream_2_group_1': 87.40234375,
 'stream_2_group_2': 34.27734375,
 'stream_2_group_3': 91.2109375,
 'stream_2_group_4': 20.3125,
 'stream_2_group_5': 60.44921875,
 'stream_2_group_6': 98.046875,
 'stream_3_group_1': 99.21875,
 'stream_3_group_2': 84.08203125,
 'stream_3_group_3': 100.0,
 'stream_3_group_4': 92.1875,
 'stream_3_group_5': 91.50390625,
 'stream_3_group_6': 89.453125,
 'stream_4_group_1': 85.3515625,
 'stream_4_group_2': 91.796875,
 'stream_4_group_3': 95.703125,
 'stream_4_group_4': 99.4140625,
 'stream_4_group_5': 86.62109375,
 'stream_4_group_6': 92.28515625,
 'stream_5_group_1': 31.15234375,
 'st

In [None]:
# Each entry of usage_stats is treated as a percentage, so the ceiling is 100 per entry.
total = 100 * len(usage_stats)

# Accumulate the observed percentages in dictionary order.
used_per = 0
for pct in usage_stats.values():
    used_per += pct

used_ratio = used_per / total
print(f"Total Used Percentage is {used_per}/{total} = {used_ratio}")
print(f"Effective Total Bitrate is {18*used_per/total:.2f}kbps/18kbps")

Total Used Percentage is 2806.73828125/3600 = 0.7796495225694444
Effective Total Bitrate is 14.03kbps/18kbps


In [None]:
# Scale each usage percentage to a bitrate contribution.
# NOTE(review): 0.5 looks like the nominal kbps per entry (18 kbps spread over
# 36 entries) and 0.01 converts percent to a fraction — confirm.
effective_bitrates = {key: .5 * val * .01 for key, val in usage_stats.items()}

cum = 0
for i in range(6):
    # The trailing underscore keeps "stream_1" from also matching keys such as
    # "stream_10" if more streams are ever added.
    prefix = f"stream_{i}_"
    bits = sum(val for key, val in effective_bitrates.items() if key.startswith(prefix))
    cum += bits
    print(f"effective_bitrates_stream_{i}:", bits, f"effective_bitrates_{3*(i+1)}kbps:", cum)


effective_bitrates_stream_0: 2.447265625 effective_bitrates_3kbps: 2.447265625
effective_bitrates_stream_1: 2.5771484375 effective_bitrates_6kbps: 5.0244140625
effective_bitrates_stream_2: 1.95849609375 effective_bitrates_9kbps: 6.98291015625
effective_bitrates_stream_3: 2.7822265625 effective_bitrates_12kbps: 9.76513671875
effective_bitrates_stream_4: 2.755859375 effective_bitrates_15kbps: 12.52099609375
effective_bitrates_stream_5: 1.5126953125 effective_bitrates_18kbps: 14.03369140625


In [7]:
# Scratch check left over from debugging: compares the int 1 with the bool True
# by value; the cell output shows it evaluates to True.
1 == True

True

In [1]:
import sys
sys.path.append("./src")
from models.codec import SwinAudioCodec

import yaml, torch
# All tensors and the model in this part of the notebook live on the CPU.
device = "cpu"

# Parse the YAML configuration for the residual 9k codec.
with open("./src/configs/residual_9k.yml", 'r') as config_file:
    config = yaml.safe_load(config_file)

# The "model" section is expanded into the constructor's keyword arguments;
# this rebinds `model` to a freshly constructed codec.
model_kwargs = config["model"]
model = SwinAudioCodec(**model_kwargs).to(device)

Audio Codec 9.0kbps Initialized
Codec Causality: False
Apply Residual-Based Cross Attention Fusion Net for Swin Codec | Type: None
Quantization Vis: 
     EMA: False CosineSimilarity: True
     Freq dims:  [2, 2, 4, 8, 16, 32]
     Channel(hidden) dims:  [384, 384, 192, 144, 96, 72]
     Merged dims:  [768, 768, 768, 1152, 1536, 2304]
     GroupVQ proj dims:  [768, 768, 768, 1152, 1536, 2304]
     GroupVQ dims (for each):  [512, 512, 512, 768, 1024, 1536]
     Mapped Codebook dims (for each):  [12, 12, 12, 12, 12, 12]
Pre-swin Layer: swin_depth=2 swin_hidden=45 heads=3 down=False
Layer[0]: swin_depth=2 swin_hidden=45 heads=3 down=True
Layer[1]: swin_depth=2 swin_hidden=72 heads=6 down=True
Layer[2]: swin_depth=2 swin_hidden=96 heads=12 down=True
Layer[3]: swin_depth=2 swin_hidden=144 heads=24 down=True
Layer[4]: swin_depth=2 swin_hidden=192 heads=24 down=True

Layer[0]: swin_depth=2 swin_hidden=384 heads=24 up=True
Layer[1]: swin_depth=2 swin_hidden=192 heads=24 up=True
Layer[2]: swin_

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [2]:
# NOTE(review): torch.load unpickles the file unless weights_only=True is
# passed — only load checkpoints from a trusted source.
checkpoint = torch.load(
    "/Users/tracy/Desktop/swin-9k-residual-gan-250k/best.pt", map_location=device
)
# Restore the weights stored under the "model_state_dict" key.
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [3]:
import torchaudio
# Load the test clip; the second value returned by torchaudio.load is discarded.
test_wav_path = "/Users/tracy/Desktop/Audio_Codec/swin-debug-vis/test/mandarin_instance1.wav"
x, _ = torchaudio.load(test_wav_path)

In [4]:
# Inference only: put the module in evaluation mode and disable autograd so the
# reconstruction is a plain tensor with no graph attached.
model.eval()
with torch.no_grad():
    codes, _ = model.encode(x=x, num_streams=6)
    print(len(codes), codes[0].shape)

    recon_audio = model.decode(codes, enc_feat_size=(2, 1000))

6 torch.Size([1, 3, 500])


In [11]:
# Write the decoded waveform to disk.
# NOTE(review): 16000 is presumably the sample rate in Hz — confirm it matches
# the rate of the input file, which was discarded when the clip was loaded.
torchaudio.save("/Users/tracy/Desktop/recon_our_gan.wav", recon_audio, 16000)