### To do list:
#### 1. ~~Скачать датасет *VGGSound*. Ссылка: https://huggingface.co/datasets/Loie/VGGSound/tree/main~~
#### 2. ~~Все видео-файлы заменить на случайный кадр из них~~
#### 2.1 ~~Перевести wav файлы к частоте 48кГц~~
#### 3. ~~Реализовать DataLoader wav->jpeg~~
#### 4. ~~Встроить аудио-энкодер. Ссылка: https://github.com/archinetai/archisound~~
#### 5. ~~Реализовать сценарий обучения и обучить модель~~

In [1]:
import torch
from utils.config import ModelConfig
from models.unet import UNetWithCrossAttention
from models.diffusion import Diffusion
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from utils.SoundDataset import SoundDataset
from torch.utils.data import DataLoader

In [2]:
# Locations of the paired image / audio files consumed by SoundDataset.
image_path = "data/images"
sound_path = "data/sounds"

data = SoundDataset(image_path, sound_path)

# Hold out a fixed number of samples for validation. The train size is derived
# from len(data) so the split keeps working if the dataset size changes
# (random_split raises when the lengths do not sum to len(data)).
val_size = 10000
if len(data) <= val_size:
    raise ValueError(
        f"Dataset has {len(data)} samples; need more than {val_size} to hold out a validation split"
    )

# A seeded generator keeps train/val membership identical across kernel
# restarts, so validation samples never leak into training when a run is resumed.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data = torch.utils.data.random_split(
    data,
    [len(data) - val_size, val_size],
    generator=split_generator,
)

# Training batches are reshuffled every epoch; incomplete final batches are dropped.
train_loader = DataLoader(train_data,
                          batch_size=96,
                          num_workers=8,
                          pin_memory=True,
                          shuffle=True,
                          drop_last=True)

# Validation keeps a fixed order and uses small batches.
val_loader = DataLoader(val_data,
                        batch_size=8,
                        num_workers=4,
                        pin_memory=True,
                        shuffle=False,
                        drop_last=True)

In [3]:
from archisound import ArchiSound

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained ArchiSound checkpoint and move it to the chosen device.
autoencoder = ArchiSound.from_pretrained("dmae1d-ATC64-v2")
autoencoder = autoencoder.to(device)

2025-05-19 11:02:06.942983: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-19 11:02:06.950350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747641726.958814   35632 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747641726.961292   35632 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-19 11:02:06.970762: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [4]:
# 128x128 images; audio_ctx_dim=32 presumably matches the channel size of the
# ArchiSound embeddings used for conditioning (verify against the encoder).
config = ModelConfig(dict(image_size=128, audio_ctx_dim=32))

# Initialization
diffusion = Diffusion(timesteps=1000, image_size=128, device=device)
model = UNetWithCrossAttention(config)
model = model.to(device)

# Per-epoch mean losses, appended to by the training cell.
train_losses, val_losses = [], []

In [None]:
# Adam with an exponentially decaying learning rate (multiplied by 0.9 after each epoch).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Probability that a training batch is run without audio conditioning
# (audio_embeds is None for the whole batch).
unconditional_prob = 0.08

epoch = 4
for _ in tqdm(range(epoch)):
    epo_train_losses, epo_val_losses = [], []

    # ---- training pass ----
    model.train()
    for audio, images in train_loader:
        if torch.rand(1) >= unconditional_prob:
            # The audio encoder stays frozen, so no gradients are tracked here.
            with torch.no_grad():
                # [B, d_audio, seq_len] -> [B, seq_len, d_audio]
                audio_embeds = autoencoder.encode(audio.to(device)).permute(0, 2, 1)
        else:
            audio_embeds = None

        images = images.to(device)

        optimizer.zero_grad()
        loss = diffusion.loss_fn(model, images, audio_embeds)
        loss.backward()
        optimizer.step()
        epo_train_losses.append(loss.item())

    scheduler.step()
    train_losses.append(sum(epo_train_losses) / len(epo_train_losses))

    # ---- validation pass (always conditioned on audio) ----
    model.eval()
    with torch.no_grad():
        for audio, images in val_loader:
            audio_embeds = autoencoder.encode(audio.to(device)).permute(0, 2, 1)
            images = images.to(device)

            loss = diffusion.loss_fn(model, images, audio_embeds)
            epo_val_losses.append(loss.item())

    val_losses.append(sum(epo_val_losses) / len(epo_val_losses))
    

  0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Mean loss per epoch for the training and validation passes.
ax = plt.gca()
ax.plot(train_losses, label='train')
ax.plot(val_losses, label='val')
ax.grid()
ax.legend()

In [None]:
# Per-batch training losses of the most recent epoch.
# The curve is drawn from epo_train_losses, so it is labelled 'train'
# (it was previously mislabelled 'val').
plt.plot(epo_train_losses, label='train')
plt.grid()
plt.legend()

In [None]:
# Take one validation batch and use its audio as the conditioning signal.
audio, image = next(iter(val_loader))

# Inference only: encode without tracking gradients so no autograd graph is
# kept alive for the embeddings.
with torch.no_grad():
    audio_embeds = autoencoder.encode(audio.to(device))
    audio_embeds = audio_embeds.permute(0, 2, 1)  # [B, d_audio, seq_len] -> [B, seq_len, d_audio]

model.eval()
generated_image = diffusion.reverse_process(
    model,
    audio_embeds,
    guidance_scale=7.5,
    # Derive the batch size from the embeddings so it always matches the
    # loader's batch size instead of a hard-coded 8.
    batch_size=audio_embeds.shape[0],
    use_ddim=True,
    timesteps=100  # number of sampling steps
)

# 16 384
# 32 768

In [None]:
# Show the first generated sample: map it from [-1, 1] to [0, 1] and move the
# channel axis last for matplotlib. detach() keeps .numpy() working if the
# tensor still carries autograd history; clamp() keeps values in the range
# imshow expects for float images (avoids the clipping warning).
plt.imshow(
    (generated_image[0].detach().cpu() * 0.5 + 0.5)
    .clamp(0.0, 1.0)
    .permute(1, 2, 0)
    .numpy()
)

In [None]:
# датасет обходится за 25 мин, 40 сек
# при num_workers = 8 обходится за ~6 мин 30 сек
# 1300 сек на эпоху

# при двух attention блоках средние потери падают до 0.54 на 11 эпохах
# при единственном блоке потери не пробивали отсечку в 0.9

# 3 attention блока дали 0.84 на 4 эпохах. При продолжении обучения к концу 7 эпохи потери не изменились

In [None]:
# get_last_lr() reports the learning rate currently in use. get_lr() is meant
# to be called only inside scheduler.step(); called directly it emits a warning
# and recomputes a value instead of reporting the current rate.
scheduler.get_last_lr()

# Пробуем другой кодировщик звука

In [3]:
import torch
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Config

class AudioEncoder:
    """Frozen Wav2Vec2 feature extractor for batches of raw waveforms.

    Calling an instance averages the channel axis (dim 1) down to mono,
    resamples from ``sample_rate`` to 16 kHz when they differ, and returns
    the ``last_hidden_state`` of the pretrained Wav2Vec2 model, computed
    under ``torch.no_grad()``.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        device: Device the model and the returned features live on.
        sample_rate: Sampling rate of the incoming audio, in Hz.
    """

    # Sampling rate handed to the Wav2Vec2 feature extractor.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self,
                 model_name,
                 device,
                 sample_rate=48000):
        self.sample_rate = sample_rate
        self.device = device

        config = Wav2Vec2Config.from_pretrained(model_name)
        config.apply_spec_augment = False

        self.model = Wav2Vec2Model.from_pretrained(model_name, config=config).to(device)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

        # Build the resampler once and keep it on the target device, instead of
        # re-creating it (and bouncing the audio through the CPU) on every call.
        self.resampler = None
        if sample_rate != self.TARGET_SAMPLE_RATE:
            self.resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.TARGET_SAMPLE_RATE
            ).to(device)

    def stereo_to_mono(self, audio_tensor):
        """Average over dim 1 (channels), keeping that axis with size 1."""
        return torch.mean(audio_tensor, dim=1, keepdim=True)

    def preprocess_audio(self, audio_mono):
        """Return ``audio_mono`` on ``self.device``, resampled to 16 kHz if needed."""
        audio_mono = audio_mono.to(self.device)
        if self.sample_rate == self.TARGET_SAMPLE_RATE:
            return audio_mono
        return self.resampler(audio_mono)

    def get_wav2vec2_features(self, audio_stereo_tensor):
        """Run Wav2Vec2 on a batch of mono waveforms.

        Args:
            audio_stereo_tensor: Waveforms shaped ``(batch, 1, samples)``,
                ``(batch, samples)`` or ``(samples,)``.

        Returns:
            ``last_hidden_state`` with shape ``(batch_size, seq_len, hidden_size)``.

        Raises:
            ValueError: If the input cannot be reduced to ``(batch, samples)``.
        """
        waveforms = audio_stereo_tensor
        if waveforms.dim() == 3:
            # Drop only the channel axis. A bare squeeze() would also drop the
            # batch axis when the batch holds a single clip.
            waveforms = waveforms.squeeze(1)
        elif waveforms.dim() == 1:
            waveforms = waveforms.unsqueeze(0)
        if waveforms.dim() != 2:
            raise ValueError(
                f"Expected mono audio shaped (batch, 1, samples) or (batch, samples), "
                f"got {tuple(audio_stereo_tensor.shape)}"
            )

        # Hand the extractor a list of 1-D clips so it treats the input as a
        # batch and normalises every clip on its own. A single 2-D tensor is
        # treated as one example and normalised across the whole batch.
        clips = list(waveforms.detach().cpu().numpy())
        inputs = self.feature_extractor(
            clips,
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        )

        inputs['input_values'] = inputs['input_values'].to(self.device)

        # The encoder is used as a frozen feature extractor.
        with torch.no_grad():
            outputs = self.model(**inputs)

        last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        return last_hidden_state

    def __call__(self, audio):
        """Encode ``audio`` (channels on dim 1) into Wav2Vec2 hidden states."""
        audio = self.stereo_to_mono(audio)
        audio = self.preprocess_audio(audio)
        features = self.get_wav2vec2_features(audio)

        return features

2025-05-23 13:49:43.214079: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-23 13:49:43.221137: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747997383.229471   29282 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747997383.232145   29282 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747997383.238766   29282 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained Wav2Vec2 checkpoint and the sampling rate (Hz) of the dataset audio.
model_name = "facebook/wav2vec2-base-960h"
sample_rate = 48000

# Callable audio encoder used by the training loop below.
encode = AudioEncoder(model_name, device, sample_rate)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# 128x128 images; audio_ctx_dim=768 presumably matches the feature size of the
# Wav2Vec2 hidden states used for conditioning (verify against the encoder).
config = ModelConfig(dict(image_size=128, audio_ctx_dim=768))

# Initialization
diffusion = Diffusion(timesteps=1000, image_size=128, device=device)
model = UNetWithCrossAttention(config)
model = model.to(device)

# Per-epoch mean losses, appended to by the training cell.
train_losses, val_losses = [], []

In [6]:
# Adam with an exponentially decaying learning rate (multiplied by 0.9 after each epoch).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Probability that a training batch is run without audio conditioning
# (audio_embeds is None for the whole batch).
unconditional_prob = 0.08

epoch = 4
for _ in tqdm(range(epoch)):
    epo_train_losses, epo_val_losses = [], []

    # ---- training pass ----
    model.train()
    for audio, images in train_loader:
        if torch.rand(1) >= unconditional_prob:
            audio_embeds = encode(audio.to(device))
        else:
            audio_embeds = None

        images = images.to(device)

        optimizer.zero_grad()
        loss = diffusion.loss_fn(model, images, audio_embeds)
        loss.backward()
        optimizer.step()
        epo_train_losses.append(loss.item())

    scheduler.step()
    train_losses.append(sum(epo_train_losses) / len(epo_train_losses))

    # ---- validation pass (always conditioned on audio) ----
    model.eval()
    with torch.no_grad():
        for audio, images in val_loader:
            audio_embeds = encode(audio.to(device))
            images = images.to(device)

            loss = diffusion.loss_fn(model, images, audio_embeds)
            epo_val_losses.append(loss.item())

    val_losses.append(sum(epo_val_losses) / len(epo_val_losses))

  0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 