### To-do list:
#### 1. ~~Скачать датасет *VGGSound*. Ссылка: https://huggingface.co/datasets/Loie/VGGSound/tree/main~~
#### 2. ~~Заменить каждый видеофайл случайным кадром из него~~
#### 2.1 ~~Привести wav-файлы к частоте дискретизации 48 кГц~~
#### 3. ~~Реализовать DataLoader wav → jpeg~~
#### 4. ~~Встроить аудио-энкодер. Ссылка: https://github.com/archinetai/archisound~~
#### 5. ~~Реализовать сценарий обучения и обучить модель~~

In [1]:
import os

import torch
import matplotlib.pyplot as plt
from torchvision.io import ImageReadMode
from tqdm.notebook import tqdm

from utils.config import ModelConfig
from models.unet import UNetWithCrossAttention
from models.diffusion import Diffusion

In [2]:
from torch.utils.data import Dataset, DataLoader
import glob
from torchvision.io import read_image
import torchaudio
import torch

class SoundDataset(Dataset):
    """Paired (audio, image) dataset built from files that share a basename.

    Every ``<name>.jpg`` in ``image_path`` is expected to have a matching
    ``<name>.wav`` in ``sound_path``. Each item is ``(waveform, image)`` where
    ``waveform`` is a float32 stereo tensor of shape ``[2, standard_len]`` and
    ``image`` is a float32 RGB tensor of shape ``[3, H, W]`` scaled to ``[0, 1]``.

    Args:
        image_path: directory containing the ``.jpg`` frames.
        sound_path: directory containing the ``.wav`` clips.
        standard_len: number of samples every waveform is zero-padded or
            truncated to (default 441000, the previously hard-coded value).

    Raises (from ``__getitem__``):
        ValueError: if a clip has neither 1 nor 2 channels.
    """

    def __init__(self, image_path, sound_path, standard_len=441000):
        # Sorted so the index -> sample mapping (and therefore any seeded
        # random_split) is identical across runs and machines; raw glob order
        # depends on the filesystem. basename/splitext instead of split('/')
        # keeps this portable, and escape() protects paths containing [ ] * ?.
        pattern = os.path.join(glob.escape(image_path), "*.jpg")
        self.names = sorted(
            os.path.splitext(os.path.basename(path))[0] for path in glob.glob(pattern)
        )
        self.im_path = image_path
        self.au_path = sound_path
        # Attribute name kept as-is (sic) for backward compatibility.
        self.stanart_len = standard_len

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]

        # Force 3-channel RGB so grayscale / CMYK / RGBA JPEGs still collate into
        # one batch tensor with a consistent channel count.
        image = read_image(
            os.path.join(self.im_path, f"{name}.jpg"), mode=ImageReadMode.RGB
        ) / 255.0

        waveform, _ = torchaudio.load(os.path.join(self.au_path, f"{name}.wav"))
        # NOTE(review): the sample rate returned by torchaudio.load is ignored;
        # clips are assumed to be pre-resampled to a common rate — confirm.

        # Stereo input is required: duplicate a mono channel into both channels.
        n_channels = waveform.shape[0]
        if n_channels == 1:
            waveform = waveform.repeat(2, 1)
        elif n_channels != 2:
            raise ValueError(f"audio {name} must be stereo or mono, but {n_channels} channels were given")

        # Batching needs a fixed length: zero-pad short clips, truncate long ones.
        n_samples = waveform.shape[1]
        if n_samples < self.stanart_len:
            waveform = torch.nn.functional.pad(waveform, (0, self.stanart_len - n_samples))
        else:
            waveform = waveform[:, :self.stanart_len]

        return waveform.float(), image.float()

In [3]:
from torch.utils.data import DataLoader

image_path = "data/images"
sound_path = "data/sounds"

VAL_SIZE = 10000  # samples held out for validation
SPLIT_SEED = 42   # fixed so the train/val assignment repeats across runs (given a stable file order)

data = SoundDataset(image_path, sound_path)
if len(data) <= VAL_SIZE:
    raise ValueError(f"dataset has {len(data)} samples, need more than VAL_SIZE={VAL_SIZE}")

# Train size is derived from the dataset rather than the previously hard-coded
# total (197889), which made random_split fail whenever the file count changed.
train_data, val_data = torch.utils.data.random_split(
    data,
    [len(data) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_data,
                          batch_size=64,
                          num_workers=8,
                          pin_memory=True,
                          shuffle=True,
                          drop_last=True)

# drop_last=False: validation should score every held-out sample, including a
# final partial batch.
val_loader = DataLoader(val_data,
                        batch_size=8,
                        num_workers=8,
                        pin_memory=True,
                        shuffle=False,
                        drop_last=False)

In [4]:
from archisound import ArchiSound

device = "cuda" if torch.cuda.is_available() else "cpu"

# The pretrained audio autoencoder is used only as a frozen feature extractor.
# eval() disables any train-time layer behaviour, and requires_grad_(False)
# guarantees no autograd graph is built through it even when encode() is
# called outside torch.no_grad(). Chained so the cell does not echo the model repr.
autoencoder = (
    ArchiSound.from_pretrained("dmae1d-ATC64-v2")
    .to(device)
    .eval()
    .requires_grad_(False)
)

2025-05-07 20:25:37.351518: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-07 20:25:37.359633: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746638737.368425  317065 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746638737.371001  317065 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-07 20:25:37.380732: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [5]:
# Hyperparameters in one place so the model config and the diffusion schedule
# cannot drift apart (image_size used to be written twice).
IMAGE_SIZE = 128
AUDIO_CTX_DIM = 32  # presumably the channel dim d_audio of the autoencoder embeddings — confirm
TIMESTEPS = 1000
# lr=0.01 is far above the usual range for Adam on diffusion UNets (DDPM used
# 2e-4) and tends to destabilise training.
LEARNING_RATE = 2e-4
SEED = 42

torch.manual_seed(SEED)  # reproducible weight init and torch-RNG-based sampling

config = ModelConfig({
    "image_size": IMAGE_SIZE,
    "audio_ctx_dim": AUDIO_CTX_DIM,
})

# Initialisation
diffusion = Diffusion(timesteps=TIMESTEPS, image_size=IMAGE_SIZE, device=device)
model = UNetWithCrossAttention(config).to(device)

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch mean losses, appended by the training loop.
train_losses = []
val_losses = []

In [6]:
epoch = 1  # number of training epochs to run


def encode_audio(audio):
    """Encode a waveform batch with the frozen audio autoencoder.

    Moves ``audio`` to ``device``, runs ``autoencoder.encode`` without tracking
    gradients, and returns the embeddings permuted from [B, d_audio, seq_len]
    to [B, seq_len, d_audio] for the UNet's cross-attention.
    """
    with torch.no_grad():
        audio_embeds = autoencoder.encode(audio.to(device, non_blocking=True))
    return audio_embeds.permute(0, 2, 1)


def _mean(values):
    """Mean of a list of floats, or NaN when the list is empty (empty loader)."""
    return sum(values) / len(values) if values else float("nan")


for _ in tqdm(range(epoch)):
    epo_train_losses = []
    epo_val_losses = []

    # Train mode must be re-enabled every epoch because validation below
    # switches the model to eval mode (matters for dropout/normalisation
    # layers, if the UNet has any).
    model.train()
    for audio, images in tqdm(train_loader, desc="train", leave=False):
        audio_embeds = encode_audio(audio)
        images = images.to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = diffusion.loss_fn(model, images, audio_embeds)
        loss.backward()
        optimizer.step()

        epo_train_losses.append(loss.item())

    train_losses.append(_mean(epo_train_losses))

    # Validation: eval mode + no_grad, so the loss is measured without
    # train-time layer behaviour and without building an autograd graph.
    model.eval()
    with torch.no_grad():
        for audio, images in tqdm(val_loader, desc="val", leave=False):
            audio_embeds = encode_audio(audio)
            images = images.to(device, non_blocking=True)
            loss = diffusion.loss_fn(model, images, audio_embeds)
            epo_val_losses.append(loss.item())

    val_losses.append(_mean(epo_val_losses))
    

  0%|          | 0/1 [00:00<?, ?it/s]

Exception in thread Thread-6 (_pin_memory_loop):
Traceback (most recent call last):
  File "/home/usr/miniforge3/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/home/usr/miniforge3/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/usr/miniforge3/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/home/usr/miniforge3/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/home/usr/miniforge3/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usr/miniforge3/lib/python3.12/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File 

KeyboardInterrupt: 

In [None]:
fig, ax = plt.subplots()
# marker="o" keeps short runs visible: a single-epoch history is one point,
# which a plain line plot does not draw at all.
ax.plot(train_losses, marker="o", label="train")
ax.plot(val_losses, marker="o", label="val")
ax.set(title="Mean diffusion loss per epoch", xlabel="epoch", ylabel="loss")
ax.grid()
ax.legend();

In [None]:
# Take one validation batch and generate images conditioned on its audio.
audio, image = next(iter(val_loader))

# Sampling is inference only: eval mode, and no autograd graph through the
# encoder, the UNet, or the reverse process.
model.eval()
with torch.no_grad():
    audio_embeds = autoencoder.encode(audio.to(device))  # [B, d_audio, seq_len]
    audio_embeds = audio_embeds.permute(0, 2, 1)  # [B, seq_len, d_audio]

    generated_image = diffusion.reverse_process(
        model,
        audio_embeds,
        batch_size=audio_embeds.shape[0],  # follow the actual batch instead of a hard-coded 8
        use_ddim=True,
        # Original note: the factor by which the number of sampling steps is
        # reduced for DDIM — TODO confirm against Diffusion.reverse_process.
        timesteps=50,
    )

In [None]:
# Show the first generated sample (channel-first per sample, hence the permute
# to H, W, C). detach() prevents .numpy() from failing on a tensor that still
# tracks gradients; clamp(0, 1) reproduces the clipping imshow applies to float
# RGB data, minus the warning.
sample = generated_image[0].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
plt.imshow(sample)
plt.axis("off");

In [None]:
# Замеры времени:
# полный проход по датасету — 25 мин 40 сек;
# при num_workers = 8 — ~6 мин 30 сек;
# ~1300 сек (~22 мин) на эпоху.

In [None]:
# Sanity check: nearest-neighbour Upsample doubles H and W -> (8, 3, 32, 32).
test = torch.rand(8, 3, 16, 16)
upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')

upsample(test).shape

In [None]:
# Learnable alternative: stride-2 transposed conv. With kernel 3 / padding 1 the
# default output for a 16x16 input is 31x31, so output_size asks for 32x32
# (only the trailing spatial dims of output_size are used).
upsample = torch.nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1)

upsample(test, output_size=(32, 32)).shape