In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# --- CONFIGURATION FROM DIAGRAM ---
# Channel widths / counts used by every model component below.
AUDIO_CHANNELS = 4     # Number of Mics
AUDIO_ENC_DIM = 512    # Internal Audio Feature Dimension
SPATIAL_DIM = 128      # Output of Spatial Stream (S)
VISUAL_DIM = 256       # Output of Visual Stream (V)

# Run on the GPU whenever PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Device: {DEVICE}")

Device: cuda


In [9]:
# Release cached GPU memory and wait for queued kernels before building the model.
# Guarded: DEVICE may fall back to "cpu", and torch.cuda.synchronize() raises
# when CUDA is unavailable, which would break Restart & Run All on CPU machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()        # free cached memory
    torch.cuda.synchronize()        # wait for all kernels to finish


In [None]:
class VisualStream(nn.Module):
    """Per-frame visual encoder producing the visual embedding V.

    Every video frame is encoded independently by a ResNet-18 backbone with
    its final classification layer removed. A Linear -> BatchNorm1d -> PReLU
    head then projects the 512-d pooled features to ``VISUAL_DIM``. The time
    axis is kept so the caller can align V with the audio frames (IsoNet
    upsamples it to the audio frame rate).

    Args:
        pretrained: if True (default), load torchvision's default ResNet-18
            weights, which may trigger a download. If False, the backbone is
            randomly initialised, e.g. for offline use or quick tests.

    NOTE(review): the pretrained weights were trained on ImageNet-normalised
    RGB images; confirm the data pipeline normalises frames the same way.
    """

    def __init__(self, pretrained: bool = True):
        super(VisualStream, self).__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Remove classification head (last child); output is [N, 512, 1, 1].
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # Project 512 -> VISUAL_DIM (V)
        self.projection = nn.Sequential(
            nn.Linear(512, VISUAL_DIM),
            nn.BatchNorm1d(VISUAL_DIM),
            nn.PReLU()
        )

    def forward(self, x):
        """Encode a video clip frame by frame.

        Args:
            x: tensor of shape [Batch, Channels, Time, H, W]; the demo cell
               uses [B, 3, T, 112, 112].
        Returns:
            Tensor of shape [Batch, VISUAL_DIM, Time].
        Raises:
            ValueError: if ``x`` is not 5-D.
        """
        if x.dim() != 5:
            raise ValueError(
                f"VisualStream expects a 5-D [B, C, T, H, W] tensor, got shape {tuple(x.shape)}"
            )
        B, C, T, H, W = x.shape

        # Fold time into batch so every frame goes through the 2-D backbone.
        frames = x.transpose(1, 2).reshape(B * T, C, H, W)

        feats = self.resnet(frames).flatten(1)   # [B*T, 512]
        feats = self.projection(feats)           # [B*T, VISUAL_DIM]

        # Unfold time and move it to the last axis: [B, VISUAL_DIM, T]
        return feats.view(B, T, -1).permute(0, 2, 1)

In [11]:
class SpatialStream(nn.Module):
    """Spatial encoder producing a clip-level spatial embedding S.

    First, GCC-PHAT cross-correlations are computed over the whole clip for
    every unordered microphone pair. Next, 1x1 Conv1d layers map the pair
    axis to ``SPATIAL_DIM`` features independently at every lag. Finally, the
    result is average-pooled over the lag axis to one vector per clip.

    Args:
        num_mics: number of microphone channels expected in the input
            (default 4, giving 4*3/2 = 6 pairs).
    """

    def __init__(self, num_mics=4):
        super(SpatialStream, self).__init__()

        # Number of unordered mic pairs: C(num_mics, 2).
        self.num_pairs = (num_mics * (num_mics - 1)) // 2

        # Spatial CNN Encoder
        # Input: [Batch, Pairs, Lags]. Pairs act as channels; kernel_size=1
        # means each lag position is transformed independently.
        self.encoder = nn.Sequential(
            nn.Conv1d(self.num_pairs, 64, kernel_size=1, stride=1),
            nn.BatchNorm1d(64),
            nn.PReLU(),
            nn.Conv1d(64, 128, kernel_size=1, stride=1),
            nn.BatchNorm1d(128),
            nn.PReLU(),
            nn.Conv1d(128, SPATIAL_DIM, kernel_size=1, stride=1)
        )

    def compute_gcc_phat(self, x):
        """Compute GCC-PHAT for every unordered microphone pair.

        Args:
            x: multichannel waveform of shape [Batch, Mics, Samples].
        Returns:
            Tensor of shape [Batch, Mics*(Mics-1)//2, Samples]. Pairs are
            ordered (0,1), (0,2), ..., (M-2, M-1). The correlation is
            circular (no zero padding): index k holds lag k, and negative
            lags wrap around to the end of the axis.
        """
        B, M, L = x.shape

        X = torch.fft.rfft(x, dim=-1)   # [B, M, L//2 + 1]

        # Upper-triangle indices enumerate all pairs i < j in row-major
        # order, matching a nested "for i / for j > i" loop.
        i, j = torch.triu_indices(M, M, offset=1, device=x.device)

        # Cross-spectrum X_i * conj(X_j) for all pairs at once: [B, P, F]
        R = X[:, i, :] * torch.conj(X[:, j, :])
        # PHAT weighting: keep only phase (eps avoids division by zero).
        R = R / (torch.abs(R) + 1e-8)

        # n=L is required: without it irfft returns L-1 samples for odd L.
        return torch.fft.irfft(R, n=L, dim=-1)   # [B, P, L]

    def forward(self, x):
        """Encode a multichannel waveform into a spatial embedding.

        Args:
            x: tensor of shape [Batch, Mics, Samples]; the mic count must
               match ``num_mics`` given at construction.
        Returns:
            Tensor of shape [Batch, SPATIAL_DIM].
        Raises:
            ValueError: if the number of input channels does not produce
                ``self.num_pairs`` mic pairs.
        """
        mics = x.shape[1]
        if (mics * (mics - 1)) // 2 != self.num_pairs:
            raise ValueError(
                f"SpatialStream was built for {self.num_pairs} mic pairs, "
                f"but input has {mics} channels"
            )

        # Analytical GCC-PHAT over the whole clip: [B, Pairs, Samples]
        gcc_feat = self.compute_gcc_phat(x)

        # Per-lag 1x1 conv encoding: [B, SPATIAL_DIM, Samples]
        x = self.encoder(gcc_feat)

        # Average over lags -> one global spatial vector S per clip.
        return torch.mean(x, dim=-1)   # [B, SPATIAL_DIM]

In [12]:
class FiLMLayer(nn.Module):
    """Feature-wise Linear Modulation: out = gamma(condition) * x + beta(condition).

    Scale (gamma) and shift (beta) come from two separate 1x1 convolutions
    over the conditioning sequence. Modulation can therefore differ per
    channel and per time step.

    Args:
        in_channels: channel count C of the features being modulated.
        cond_dim: channel count of the conditioning tensor (S+V).

    NOTE(review): gamma is not offset by 1, so the initial modulation is
    whatever the default Conv1d init yields; confirm this is intended.
    """

    def __init__(self, in_channels, cond_dim):
        super(FiLMLayer, self).__init__()
        self.conv_gamma = nn.Conv1d(cond_dim, in_channels, 1)   # scale
        self.conv_beta = nn.Conv1d(cond_dim, in_channels, 1)    # shift

    def forward(self, x, condition):
        """Modulate ``x`` [B, C, T] with ``condition`` [B, cond_dim, T]; returns [B, C, T]."""
        scale = self.conv_gamma(condition)
        shift = self.conv_beta(condition)
        return x * scale + shift

class ExtractionBlock(nn.Module):
    """Residual TCN block with FiLM conditioning.

    Data flow on [B, C, T] tensors:
        1x1 conv (in -> hid) -> GroupNorm(1 group) -> PReLU -> FiLM(condition)
        -> dilated depthwise conv (kernel 3) -> GroupNorm(1 group) -> PReLU
        -> 1x1 conv (hid -> in) -> add the block input
    With kernel 3 and padding == dilation, the depthwise conv keeps T unchanged.
    """

    def __init__(self, in_channels, hid_channels, cond_dim, dilation):
        super(ExtractionBlock, self).__init__()
        # Pointwise expansion into the hidden width.
        self.conv1 = nn.Conv1d(in_channels, hid_channels, 1)
        self.norm1 = nn.GroupNorm(1, hid_channels)
        self.prelu1 = nn.PReLU()

        # Conditioning is injected right after the first activation.
        self.film = FiLMLayer(hid_channels, cond_dim)

        # Depthwise (groups == channels) dilated temporal convolution.
        self.dconv = nn.Conv1d(
            hid_channels,
            hid_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            groups=hid_channels,
        )
        self.norm2 = nn.GroupNorm(1, hid_channels)
        self.prelu2 = nn.PReLU()

        # Pointwise projection back to the residual width.
        self.conv2 = nn.Conv1d(hid_channels, in_channels, 1)

    def forward(self, x, condition):
        """Apply the block to ``x`` [B, in_channels, T] under ``condition``; shape is preserved."""
        hidden = self.prelu1(self.norm1(self.conv1(x)))
        # The condition (S+V) modulates the hidden features here.
        hidden = self.film(hidden, condition)
        hidden = self.prelu2(self.norm2(self.dconv(hidden)))
        return self.conv2(hidden) + x

In [13]:
class IsoNet(nn.Module):
    """Audio-visual-spatial speech extraction network.

    Pipeline:
      * spatial stream -> S [B, SPATIAL_DIM] (one vector per clip);
      * visual stream  -> V [B, VISUAL_DIM, T_v];
      * learned Conv1d encoder maps the multichannel mix to [B, AUDIO_ENC_DIM, T_a];
      * S (broadcast over time) and V (upsampled to T_a) are concatenated into
        the FiLM condition;
      * a stack of FiLM-conditioned TCN blocks predicts a sigmoid mask over the
        encoded mix;
      * a transposed conv decodes the masked features to a mono waveform.

    Args:
        num_blocks: number of ExtractionBlocks; block i uses dilation 2**i.
        hid_channels: hidden width of every ExtractionBlock.
    """

    def __init__(self, num_blocks=8, hid_channels=128):
        super(IsoNet, self).__init__()

        # 1. Streams
        self.visual_stream = VisualStream()  # Output: VISUAL_DIM
        self.spatial_stream = SpatialStream(AUDIO_CHANNELS) # Output: SPATIAL_DIM

        # 2. Audio Encoder (Purple box start)
        self.audio_enc = nn.Conv1d(AUDIO_CHANNELS, AUDIO_ENC_DIM, kernel_size=16, stride=8, bias=False)

        # 3. Conditioning Prep: concatenate S + V along channels
        self.cond_dim = SPATIAL_DIM + VISUAL_DIM

        # 4. TCN with FiLM (Purple box middle)
        self.tcn_blocks = nn.ModuleList([
            ExtractionBlock(AUDIO_ENC_DIM, hid_channels, self.cond_dim, dilation=2**i)
            for i in range(num_blocks)
        ])

        # 5. Mask Decoder (Purple box end)
        self.mask_conv = nn.Conv1d(AUDIO_ENC_DIM, AUDIO_ENC_DIM, 1)
        self.sigmoid = nn.Sigmoid()

        # 6. Audio Decoder (Reconstructs waveform)
        self.audio_dec = nn.ConvTranspose1d(AUDIO_ENC_DIM, 1, kernel_size=16, stride=8, bias=False)

    def _encoder_padding(self, num_samples):
        """Right-pad length so encoder frames tile the input and the decoder output covers it."""
        kernel = self.audio_enc.kernel_size[0]
        stride = self.audio_enc.stride[0]
        if num_samples < kernel:
            return kernel - num_samples
        return (-(num_samples - kernel)) % stride

    def forward(self, audio_mix, video_frames):
        """Extract the target speech from a multichannel mixture.

        Args:
            audio_mix: [B, AUDIO_CHANNELS, Samples] waveform.
            video_frames: [B, 3, T_v, H, W] clip.
        Returns:
            [B, 1, Samples] waveform, with exactly the input sample count.
        """
        num_samples = audio_mix.shape[-1]

        # --- A. Spatial Stream: global spatial embedding S [B, SPATIAL_DIM]
        S = self.spatial_stream(audio_mix)

        # --- B. Visual Stream: visual embedding V [B, VISUAL_DIM, T_v]
        V = self.visual_stream(video_frames)

        # --- C. Audio Encoding ---
        # Without padding, the decoder returns up to (stride - 1) fewer samples
        # than the input whenever (Samples - kernel) is not a multiple of stride.
        pad = self._encoder_padding(num_samples)
        audio_in = F.pad(audio_mix, (0, pad)) if pad else audio_mix
        audio_feat = self.audio_enc(audio_in)  # [B, AUDIO_ENC_DIM, T_a]
        num_frames = audio_feat.shape[-1]

        # --- D. Synchronization (Upsampling) ---
        # Video frames are fewer than audio frames; repeat V to match T_a.
        V_upsampled = F.interpolate(V, size=num_frames, mode='nearest')

        # Broadcast S over time: [B, SPATIAL_DIM] -> [B, SPATIAL_DIM, T_a]
        S_expanded = S.unsqueeze(-1).expand(-1, -1, num_frames)

        # Conditioning tensor: [B, SPATIAL_DIM + VISUAL_DIM, T_a]
        condition = torch.cat([S_expanded, V_upsampled], dim=1)

        # --- E. FiLM Extraction Loop: every block sees the same condition ---
        x = audio_feat
        for block in self.tcn_blocks:
            x = block(x, condition)

        # --- F. Masking & Decoding ---
        mask = self.sigmoid(self.mask_conv(x))
        masked_feat = audio_feat * mask
        clean_speech = self.audio_dec(masked_feat)   # [B, 1, Samples + pad]

        # Drop the padded tail so output length == input length.
        return clean_speech[..., :num_samples]

In [14]:
# Create Model
model = IsoNet().to(DEVICE)
print(f"IsoNet Created. Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Dummy Data
dummy_audio = torch.randn(2, 4, 64000).to(DEVICE)     # 4 seconds audio
dummy_video = torch.randn(2, 3, 100, 112, 112).to(DEVICE) # 100 frames

# Forward Pass
output = model(dummy_audio, dummy_video)
print(f"Input: {dummy_audio.shape}")
print(f"Output: {output.shape}")

# Check
if output.shape[1] == 1 and abs(output.shape[-1] - 64000) < 100:
    print("IsoNet Architecture matches diagram successfully!")
else:
    print("IsoNet Architecture does not match diagram.")

IsoNet Created. Parameters: 13,488,019
Input: torch.Size([2, 4, 64000])
Output: torch.Size([2, 1, 64000])
IsoNet Architecture matches diagram successfully!
