In [1]:
import sys
print(sys.version)

3.8.18 (default, Sep 11 2023, 13:40:15) 
[GCC 11.2.0]


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
import torch.nn as nn
import math
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))

0 NVIDIA GeForce RTX 3090


In [3]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [4]:
batch_size = 16
n_epochs = 500
model_no = 'transformer_with_word2vec'
exp = 1

In [27]:
class custon_transformer(nn.Module):

    def __init__(self, no_out_vect = 128, num_classes=128, hidden_dim=512, nheads=8, num_encoder_layers=5, 
                 num_decoder_layers=5):
        super(custon_transformer, self).__init__()
        self.hidden_dim = hidden_dim
#         self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers, 
#                                           batch_first=True, activation="relu")
        self.encoder = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nheads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder, num_layers = num_encoder_layers)
        
        self.decoder = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nheads)
        self.transformer_decoder = nn.TransformerEncoder(self.encoder, num_layers = num_encoder_layers)
        
        self.relu_layer = nn.ReLU()
        self.sigmoid_layer = nn.Sigmoid()
        
    def positionalencoding1d(self, d_model, length):
        """
        :param d_model: dimension of the model
        :param length: length of positions
        :return: length*d_model position matrix
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                             -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)

        return pe

    def forward(self, feat_input):
        feat_input = feat_input.flatten(2).permute(0, 2, 1)

        feat_input += self.positionalencoding1d(self.hidden_dim, feat_input.shape[-2]).repeat(feat_input.shape[0], 1, 1)
#         features = self.transformer(feat_input.cuda(), self.learnable_query.repeat(feat_input.shape[0], 1, 1))
        enc_features = self.transformer_encoder(feat_input)
        dec_features = self.transformer_decoder(feat_input, enc_features)
        features = self.linear1(features.flatten(1))
        features = self.sigmoid_layer(features)

        return features

In [28]:
new_model = custon_transformer()
# new_model.to(device)
print("Total trainable params:", torch.nn.utils.parameters_to_vector([p for p in new_model.parameters() if p.requires_grad]).numel())

Total trainable params: 37828608


In [29]:
new_model

custon_transformer(
  (encoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-4): 5 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
    