In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models

from torchvision.models import EfficientNet_B0_Weights  # Import the weights

# EfficientNet model
class DualEfficientNetModel(nn.Module):
    """Two-branch EfficientNet-B0 classifier over mel-spectrogram and MFCC inputs.

    Each branch is an ImageNet-pretrained EfficientNet-B0 whose stem
    convolution is replaced by a single-channel one, so both inputs are
    expected to be 1-channel image tensors. The 1000-way outputs of the two
    backbones are concatenated and passed through a small fully connected
    head that produces ``num_classes`` logits.

    Args:
        num_classes: Number of output logits produced by the head.

    Note:
        Submodule names and parameter shapes match the earlier version of
        this class, so existing ``state_dict`` files still load. The head now
        applies a ReLU between ``fc1`` and ``fc2``; checkpoints trained
        without it will produce different outputs.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # B0 model for the mel-spectrogram branch.
        self.efficientnet_mel = self._build_grayscale_backbone()

        # B0 model for the MFCC branch.
        self.efficientnet_mfcc = self._build_grayscale_backbone()

        # Fully connected layers -> combine the outputs from both networks.
        # Each backbone keeps its stock 1000-way classifier, hence 1000 * 2.
        self.fc1 = nn.Linear(1000 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

    @staticmethod
    def _build_grayscale_backbone():
        """Return a pretrained EfficientNet-B0 whose stem accepts 1-channel input."""
        # `weights=` replaces the deprecated `pretrained=True` flag and
        # selects the same ImageNet checkpoint.
        backbone = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

        rgb_conv = backbone.features[0][0]
        gray_conv = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Keep the pretrained stem instead of starting it from random values:
        # summing the RGB kernels gives the response the original conv would
        # produce for a grayscale image replicated across three channels.
        with torch.no_grad():
            gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))

        backbone.features[0][0] = gray_conv
        return backbone

    def forward(self, mel_input, mfcc_input):
        """Compute class logits for a batch of paired inputs.

        Args:
            mel_input: Batch fed to the mel-spectrogram backbone
                (1-channel, as required by the replaced stem conv).
            mfcc_input: Batch fed to the MFCC backbone (1-channel), with the
                same batch size as ``mel_input``.

        Returns:
            Tensor of raw logits with ``num_classes`` values per sample.
        """
        mel_features = self.efficientnet_mel(mel_input)
        mfcc_features = self.efficientnet_mfcc(mfcc_input)

        # Concatenate the features from both networks along the feature axis.
        combined_features = torch.cat((mel_features, mfcc_features), dim=1)

        # Without a non-linearity, fc1 followed by fc2 is equivalent to a
        # single linear layer; the ReLU makes the hidden layer meaningful.
        hidden = torch.relu(self.fc1(combined_features))
        return self.fc2(hidden)

# Two target classes: fake or real.
num_classes = 2
model = DualEfficientNetModel(num_classes=num_classes)

# Dump the module tree to check the assembled architecture.
print(model)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 77.9MB/s]


DualEfficientNetModel(
  (efficientnet_mel): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
   