In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataloader import ARC_Dataset
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math


In [6]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block: self-attention followed by a feed-forward MLP.

    Each sub-layer normalizes its input first and adds its dropout-regularized
    output back onto the residual stream:

        x = x + Dropout(SelfAttn(LN1(x)))
        x = x + Dropout(Linear2(Dropout(act(Linear1(LN2(x))))))

    No LayerNorm is applied to the final output; a pre-LN stack usually adds one
    after the last block (not done here).

    Args:
        d_model: token embedding size.
        nhead: number of attention heads; must divide ``d_model``
            (requirement of ``nn.MultiheadAttention``).
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used inside attention and after each sub-layer.
        activation: feed-forward non-linearity. Either a callable mapping a tensor
            to a tensor, or one of ``"relu"`` / ``"gelu"``. Defaults to ``"relu"``,
            which matches the block's original hard-coded behavior.

    Raises:
        ValueError: if ``activation`` is a string other than ``"relu"`` or ``"gelu"``.
    """

    # Names accepted for ``activation`` when given as a string (mirrors nn.TransformerEncoderLayer).
    _ACTIVATIONS = {"relu": F.relu, "gelu": F.gelu}

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__()
        if isinstance(activation, str):
            if activation not in self._ACTIVATIONS:
                raise ValueError(
                    f"unsupported activation {activation!r}; "
                    f"expected one of {sorted(self._ACTIVATIONS)} or a callable"
                )
            activation = self._ACTIVATIONS[activation]
        self.activation = activation

        # Built with the default batch_first=False, so inputs are (seq_len, batch, d_model).
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-layer; `self.dropout` sits between the two projections.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Pre-norms for the attention (norm1) and feed-forward (norm2) sub-layers,
        # plus the dropouts applied to each sub-layer's output before the residual add.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        """Apply the block.

        Args:
            x: input of shape (seq_len, batch, d_model), the layout expected by
                ``nn.MultiheadAttention`` when ``batch_first`` is left False.
            src_mask: optional attention mask, forwarded as ``attn_mask``.
            src_key_padding_mask: optional (batch, seq_len) mask, forwarded as
                ``key_padding_mask``; marked key positions are ignored by attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        x = x + self._ff_block(self.norm2(x))
        return x

    def _sa_block(self, h, attn_mask, key_padding_mask):
        """Self-attention sub-layer on the already-normalized input ``h``."""
        # Query, key and value are all `h`; [0] drops the attention weights
        # that nn.MultiheadAttention returns alongside its output.
        attn_out = self.self_attn(h, h, h, attn_mask=attn_mask,
                                  key_padding_mask=key_padding_mask)[0]
        return self.dropout1(attn_out)

    def _ff_block(self, h):
        """Feed-forward sub-layer on the already-normalized input ``h``."""
        hidden = self.dropout(self.activation(self.linear1(h)))
        return self.dropout2(self.linear2(hidden))