# MusicLM model
### Robert Chen, Ahmadsho Akdodshoev, Philip Timofeev

## 1. Imports

In [None]:
import torch
from torch import nn, distributed as dist
from torch.nn import functional as F
import numpy as np
import einops
from einops.layers.torch import Rearrange
import math
from functools import wraps, partial
from torchaudio import *

## 2. Auxiliary methods

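Helper functions for distributed training: gathering tensors from all ranks, including tensors whose size along the gathered dimension differs between ranks.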

In [None]:
def pad_dim_to(t, length, dim = 0):
    """Zero-pad ``t`` at the end of dimension ``dim`` until that dimension has size ``length``.

    Args:
        t: tensor to pad.
        length: target size of dimension ``dim``. It is expected to be at least
            ``t.shape[dim]``; a smaller value hands ``F.pad`` a negative amount.
        dim: dimension to pad; negative values count from the end.

    Returns:
        The padded tensor; every other dimension is left unchanged.
    """
    missing = length - t.shape[dim]

    # F.pad lists (left, right) pairs starting from the LAST dimension, so each
    # dimension that comes after ``dim`` needs a (0, 0) pair first.
    if dim < 0:
        dims_after = -dim - 1
    else:
        dims_after = t.ndim - dim - 1

    padding = [0, 0] * dims_after + [0, missing]
    return F.pad(t, tuple(padding))
    
def _all_gather(tensor):
    """Collect ``tensor`` from every rank of the default process group.

    Every rank must pass a tensor of the same shape and dtype, because the
    receive buffers are allocated with ``empty_like(tensor)``.

    Returns:
        A list with one tensor per rank, in rank order as filled by
        ``dist.all_gather``.
    """
    buffers = []
    for _ in range(dist.get_world_size()):
        # Uninitialized receive buffer; dist.all_gather overwrites it in place.
        buffers.append(torch.empty_like(tensor, device=tensor.device, dtype=tensor.dtype))
    dist.all_gather(buffers, tensor)
    return buffers

def all_gather(tensor, dim=0, sizes=None):
    """Gather ``tensor`` from all ranks and concatenate along ``dim``.

    The size of ``dim`` may differ between ranks. Every other dimension is
    assumed to match across ranks (required by the underlying collective).

    Args:
        tensor: this rank's contribution.
        dim: dimension along which the per-rank tensors are concatenated.
        sizes: optional per-rank sizes of ``dim``, in rank order. It may be a
            1-D tensor or a sequence of ints and is moved to ``tensor.device``.
            When omitted, the sizes are exchanged with an extra collective.

    Returns:
        ``(gathered, sizes)``:
        - ``gathered`` is the concatenation of every rank's tensor with any
          padding removed.
        - ``sizes`` is a 1-D tensor of per-rank sizes on ``tensor.device``.
    """
    device = tensor.device

    if sizes is None:
        # Exchange each rank's length along ``dim`` so padding can be undone below.
        size = torch.tensor(tensor.shape[dim], device=device, dtype=torch.long)
        sizes = torch.stack(_all_gather(size))
    else:
        # Accept lists / tensors on another device; the mask below compares
        # ``sizes`` against a tensor on ``device``, so they must live together.
        sizes = torch.as_tensor(sizes, device=device)

    # Fast path: every rank has the same length, so a plain all_gather suffices.
    if torch.unique(sizes).numel() == 1:
        return torch.cat(_all_gather(tensor), dim=dim), sizes

    # Pad every contribution to the longest length so all ranks send equally
    # shaped tensors, then drop the padded positions after concatenation.
    pad_size = sizes.amax().item()
    padded_tensor = pad_dim_to(tensor, pad_size, dim=dim)

    # mask[i * pad_size + j] is True when position j of rank i's block is real data.
    positions = torch.arange(pad_size, device=device)
    mask = einops.rearrange(
        einops.rearrange(positions, 'j -> 1 j') < einops.rearrange(sizes, 'i -> i 1'),
        'i j -> (i j)',
    )
    idx = torch.arange(mask.shape[-1], device=device)[mask]

    gathered = torch.cat(_all_gather(padded_tensor), dim=dim).index_select(dim, idx)
    return gathered, sizes

Tensor methods:

## 3. Modules

In [None]:
class AllGatherFunc(torch.autograd.Function):
    """Autograd-aware wrapper around ``all_gather``.

    Forward gathers ``x`` from every rank along ``dim`` (the size of that
    dimension may differ per rank) and returns ``(gathered, sizes)``.

    Backward splits the incoming gradient by the recorded per-rank sizes and
    returns only this rank's slice. Backward performs no collective, so only
    the gradient that reaches this rank's copy of ``gathered`` is propagated.
    """

    @staticmethod
    def forward(ctx, x, dim, sizes, grads=None):
        # ``grads`` is unused. It is kept (now optional) so that callers using the
        # original four-argument signature keep working.
        gathered, sizes = all_gather(x, dim=dim, sizes=sizes)
        # Plain Python ints are what Tensor.split expects in backward.
        ctx.sizes = sizes.tolist()
        ctx.dim = dim
        return gathered, sizes

    @staticmethod
    def backward(ctx, grads, _):
        # ``_`` is the gradient slot for the integer ``sizes`` output; it is ignored.
        rank = dist.get_rank()
        grad_x = grads.split(ctx.sizes, dim=ctx.dim)[rank]
        # One entry per forward input (x, dim, sizes, grads); only x is differentiable.
        return grad_x, None, None, None
    
class AllGather(nn.Module):
    """``nn.Module`` that gathers its input from all ranks via ``AllGatherFunc``.

    Args:
        dim: keyword-only; dimension along which per-rank tensors are
            concatenated (default 0).
        *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, *args, dim=0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Dimension along which per-rank tensors are concatenated.
        self.dim = dim

    def forward(self, x, sizes=None):
        """Gather ``x`` from all ranks along ``self.dim``.

        ``sizes`` optionally supplies the per-rank sizes of that dimension so
        they need not be exchanged. Returns whatever ``AllGatherFunc.apply``
        returns.
        """
        # Pass all four forward arguments explicitly (the last one is unused by
        # the gather itself) so the call matches AllGatherFunc.forward's signature.
        return AllGatherFunc.apply(x, self.dim, sizes, None)
    