In [1]:
import mgr.predict as predict
import mgr

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import torchvision

# Configuration mapping obtained through the `mgr` package. The script section
# further down reads CFG['device'] when building the model and when choosing
# the map_location for the checkpoint.
CFG = mgr.configuration.load_configurations()

  return yaml.load(file)


In [3]:
class VisionTransformer(nn.Module):
    """Convolutional front end followed by transformer blocks and a classifier.

    The input is passed through a torchvision ResNet-34 with its last five
    child modules removed and its first convolution replaced by one that takes
    a single input channel. The result is flattened over dims 2 and 3,
    projected from 768 features to ``embed_dim``, prefixed with a learned class
    token, shifted by a learned positional embedding and fed through
    ``num_layers`` ``AttentionBlock`` modules. The class-token output is layer
    normalised and mapped linearly to ``num_classes`` unnormalised scores.

    Args:
        embed_dim: Width of every token inside the transformer.
        hidden_dim: Width of the feed-forward layer in each ``AttentionBlock``.
        num_channels: Accepted but never read; the replacement first
            convolution is always created with one input channel.
        num_head: Number of attention heads in each block.
        num_layers: Number of ``AttentionBlock`` modules stacked in sequence.
        num_classes: Size of the output vector.
        dropout: Dropout probability used on the embedded tokens and inside
            every block.
        h_patch: Stored in ``self.patch_size`` only.
        w_patch: Stored in ``self.patch_size`` only.
        num_patches: Number of rows in the positional embedding. ``forward``
            takes the first ``T + 1`` rows, where ``T`` is the size of dim 1
            after flattening, so the class token has to be counted here.
        device: Stored in ``self.device`` only; the module is not moved.
    """

    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_head,
        num_layers,
        num_classes,
        dropout=0.0,
        h_patch=9,
        w_patch=14,
        num_patches=33,
        device="cpu",
    ):
        super().__init__()
        self.device = device
        self.patch_size = (h_patch, w_patch)

        # Pretrained ResNet-34 whose first convolution is swapped for a
        # single-input-channel one; everything except the last five child
        # modules is kept as the feature extractor.
        backbone = torchvision.models.resnet34(pretrained=True)
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet = nn.Sequential(*list(backbone.children())[:-5])

        # The flattened last dimension of the backbone output must be 768.
        self.input_layer = nn.Linear(768, embed_dim)

        blocks = [
            AttentionBlock(embed_dim, hidden_dim, num_head, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*blocks)

        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Learned class token and positional embedding, both drawn from a
        # standard normal distribution.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, embed_dim))

    def forward(self, x):
        """Return one score vector of length ``num_classes`` per batch element."""
        features = self.resnet(x)

        # Merge dims 2 and 3 so that each entry along dim 1 becomes one token.
        tokens = torch.flatten(features, start_dim=2, end_dim=3)
        batch_size, seq_len, _ = tokens.shape

        tokens = self.input_layer(tokens)
        cls_tokens = self.cls_token.repeat(batch_size, 1, 1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # One extra positional row accounts for the prepended class token.
        tokens = tokens + self.pos_embedding[:, : seq_len + 1]

        tokens = self.dropout(tokens)
        # The attention blocks receive the sequence dimension first.
        encoded = self.transformer(tokens.transpose(0, 1))

        # Index 0 along the sequence dimension is the class token.
        return self.mlp_head(encoded[0])

class AttentionBlock(nn.Module):
    """Pre-norm self-attention block with a two-layer feed-forward network.

    Both sub-layers are wrapped in a residual connection: the input is layer
    normalised, passed through multi-head self-attention and added back, then
    layer normalised again, passed through the feed-forward network and added
    back.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        hidden_dim: Width of the intermediate feed-forward layer.
        num_heads: Number of heads given to ``nn.MultiheadAttention``.
        dropout: Dropout probability applied after the ReLU and after the
            final linear layer of the feed-forward network.
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)

        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.linear = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with a skip connection."""
        normed = self.layer_norm_1(x)
        # Query, key and value are the same tensor; the returned attention
        # weights are not needed.
        attended, _ = self.attn(normed, normed, normed)
        x = x + attended

        return x + self.linear(self.layer_norm_2(x))
    

if __name__ == "__main__":
    # NOTE(review): both paths are absolute and machine-specific; point them at
    # local copies (or derive them from a configurable data directory) before
    # running this cell elsewhere.
    CHECKPOINT_PATH = "/Users/shubhampatel/Downloads/cnn_transformer_v3.pt"
    AUDIO_PATH = "/Users/shubhampatel/Downloads/hiphop.00001.wav"

    # num_patches=65 sizes the positional embedding: the forward pass uses
    # T + 1 rows (sequence length plus the class token).
    model = VisionTransformer(
            embed_dim=256,
            hidden_dim=512,
            num_head=4,
            num_layers=36,
            h_patch=9,
            w_patch=14,
            num_patches=65,
            num_channels=1,
            num_classes=8,
            dropout=0.3,
            device=CFG['device']
    )

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    ckpts = torch.load(CHECKPOINT_PATH, map_location=CFG['device'])
    model.load_state_dict(ckpts['model'])

    # Put the model in evaluation mode before predicting: it was built with
    # dropout=0.3, and in the default training mode dropout stays active and
    # any batch-norm layers in the backbone use batch statistics, making the
    # prediction non-deterministic. The call is idempotent, so it is harmless
    # if predict.predict also switches modes.
    model.eval()

    print("Genre: ", predict.predict(model, AUDIO_PATH))

tensor([1.2221e-01, 8.0737e-01, 5.2836e-06, 6.9841e-02, 1.5784e-05, 3.7591e-05,
        5.2031e-04, 4.2513e-10])
Genre:  Experimental
