In [None]:
# default_exp models.layers.encoding

# Encoding Layers
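> Self-attention encoder building blocks: feed-forward sub-layers, single attention layers and stacked encoders, plus distribution-based (mean/covariance) variants.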

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
import torch
import torch.nn as nn
import torch.nn.functional as F

import copy

from recohut.models.layers.activation import gelu, swish
from recohut.models.layers.attention import SelfAttention, DistSelfAttention, DistMeanSelfAttention

In [None]:
#exporti
# Lookup table used when a layer is configured with an activation *name* rather
# than a callable (see `Intermediate`).
ACT2FN = dict(gelu=gelu, relu=F.relu, swish=swish)

In [None]:
#export
class Intermediate(nn.Module):
    """Feed-forward sub-layer with a residual connection and LayerNorm.

    Computes ``LayerNorm(Dropout(W2 · act(W1 · x)) + x)`` where ``W1`` maps the
    last dimension from ``hidden_size`` to ``4 * hidden_size`` and ``W2`` maps it back.

    Args:
        hidden_size: size of the last dimension of the input and output.
        hidden_act: an activation callable, or one of the ``ACT2FN`` keys
            (``"gelu"``, ``"relu"``, ``"swish"``).
        hidden_dropout_prob: dropout probability applied after the second projection.
    """
    def __init__(self, hidden_size, hidden_act, hidden_dropout_prob):
        super().__init__()
        inner_size = 4 * hidden_size
        # Submodules are created in this exact order so parameter initialisation
        # (and state_dict ordering) is reproducible.
        self.dense_1 = nn.Linear(hidden_size, inner_size)
        self.intermediate_act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act
        self.dense_2 = nn.Linear(inner_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_tensor):
        """Return the normalised residual sum; output shape equals input shape."""
        projected = self.dense_2(self.intermediate_act_fn(self.dense_1(input_tensor)))
        return self.layernorm(self.dropout(projected) + input_tensor)

In [None]:
# Smoke test: Intermediate keeps the shape of a (4, 4) input.
hidden_size, hidden_act, hidden_dropout_prob = 4, 'swish', 0.5

layer = Intermediate(hidden_size, hidden_act, hidden_dropout_prob)
input_tensor = torch.rand(4, 4)
output = layer(input_tensor)

test_eq(output.shape.numel(), 16)
test_eq(output.shape, [4, 4])

In [None]:
#export
class DistIntermediate(nn.Module):
    """Feed-forward sub-layer used by the distribution-based layers.

    Same wiring as ``Intermediate`` but with a fixed ``nn.ELU`` activation:
    Linear(hidden_size -> 4*hidden_size) -> ELU -> Linear(4*hidden_size -> hidden_size)
    -> dropout -> residual add -> LayerNorm.

    Args:
        hidden_size: size of the last dimension of the input and output.
        hidden_dropout_prob: dropout probability applied after the second projection.
    """
    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        expanded = hidden_size * 4
        # Creation order matters for reproducible initialisation / state_dict order.
        self.dense_1 = nn.Linear(hidden_size, expanded)
        self.intermediate_act_fn = nn.ELU()
        self.dense_2 = nn.Linear(expanded, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_tensor):
        """Return ``LayerNorm(Dropout(FFN(x)) + x)``; output shape equals input shape."""
        residual = input_tensor
        x = self.dense_1(input_tensor)
        x = self.intermediate_act_fn(x)
        x = self.dropout(self.dense_2(x))
        return self.layernorm(x + residual)

In [None]:
# Smoke test: DistIntermediate keeps the shape of a (4, 4) input.
hidden_size, hidden_dropout_prob = 4, 0.5

layer = DistIntermediate(hidden_size, hidden_dropout_prob)
input_tensor = torch.rand(4, 4)
output = layer(input_tensor)

test_eq(output.shape.numel(), 16)
test_eq(output.shape, [4, 4])

In [None]:
#export
class Layer(nn.Module):
    """A ``SelfAttention`` sub-layer followed by an ``Intermediate`` feed-forward sub-layer.

    Args:
        hidden_size: hidden width shared by attention and feed-forward.
        hidden_act: activation for ``Intermediate`` (callable or ``ACT2FN`` key).
        num_attention_heads: forwarded to ``SelfAttention``.
        hidden_dropout_prob: forwarded to ``SelfAttention`` and ``Intermediate``.
        attention_probs_dropout_prob: forwarded to ``SelfAttention``.
    """
    def __init__(self, hidden_size, hidden_act, num_attention_heads,
                 hidden_dropout_prob, attention_probs_dropout_prob):
        super().__init__()
        # NOTE(review): attention dropout is passed before hidden dropout here, the
        # reverse of the DistSelfAttention calls in this file — confirm this matches
        # SelfAttention's signature.
        self.attention = SelfAttention(hidden_size, num_attention_heads,
                                       attention_probs_dropout_prob, hidden_dropout_prob)
        self.intermediate = Intermediate(hidden_size, hidden_act, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask):
        """Return ``(feed_forward_output, attention_scores)``."""
        attended, scores = self.attention(hidden_states, attention_mask)
        return self.intermediate(attended), scores

In [None]:
# Layer keeps the input's shape and returns attention scores of shape (2, 2, 4, 4)
# for num_attention_heads=2.
hidden_size = 4
hidden_act = 'gelu'
num_attention_heads = 2
hidden_dropout_prob = 0.2
attention_probs_dropout_prob = 0.2

layer = Layer(hidden_size, hidden_act, num_attention_heads,
              hidden_dropout_prob, attention_probs_dropout_prob)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

# Run the layer once and check both outputs of that same pass, instead of feeding
# the rounded hidden-state output back through a second forward call.
output = layer.forward(input_tensor, attention_mask)

hidden_states = torch.round(output[0].detach() * 1e4) / 1e4

test_eq(hidden_states.shape.numel(), 32)
test_eq(list(hidden_states.shape), [2, 4, 4])

attention_probs = torch.round(output[1].detach() * 1e4) / 1e4

test_eq(attention_probs.shape.numel(), 64)
test_eq(list(attention_probs.shape), [2, 2, 4, 4])

In [None]:
#export
class DistLayer(nn.Module):
    """Distribution-based layer: ``DistSelfAttention`` followed by separate
    ``DistIntermediate`` feed-forward blocks for the mean and covariance streams.

    Args:
        hidden_size: hidden width of both streams.
        num_attention_heads: forwarded to ``DistSelfAttention``.
        hidden_dropout_prob: forwarded to the attention and both feed-forward blocks.
        attention_probs_dropout_prob: forwarded to ``DistSelfAttention``.
        distance_metric: forwarded to ``DistSelfAttention`` (default ``'wasserstein'``).
    """
    def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,
                 attention_probs_dropout_prob, distance_metric='wasserstein'):
        super().__init__()
        self.attention = DistSelfAttention(
            hidden_size, num_attention_heads, hidden_dropout_prob,
            attention_probs_dropout_prob, distance_metric)
        ffn_args = (hidden_size, hidden_dropout_prob)
        self.mean_intermediate = DistIntermediate(*ffn_args)
        self.cov_intermediate = DistIntermediate(*ffn_args)
        self.activation_func = nn.ELU()

    def forward(self, mean_hidden_states, cov_hidden_states, attention_mask):
        """Return ``(mean_output, cov_output, attention_scores)``."""
        attended_mean, attended_cov, scores = self.attention(
            mean_hidden_states, cov_hidden_states, attention_mask)
        # The mean branch runs first: both branches use dropout, so call order
        # affects the random masks drawn.
        mean_out = self.mean_intermediate(attended_mean)
        # nn.ELU (default alpha=1) is bounded below by -1, so the +1 shift keeps the
        # covariance stream non-negative.
        cov_out = self.activation_func(self.cov_intermediate(attended_cov)) + 1
        return mean_out, cov_out, scores

In [None]:
# DistLayer: both streams keep the input shape; attention scores come back as
# (2, 2, 4, 4) for num_attention_heads=2.
hidden_size, num_attention_heads = 4, 2
hidden_dropout_prob = attention_probs_dropout_prob = 0.2

layer = DistLayer(hidden_size, num_attention_heads, hidden_dropout_prob,
                  attention_probs_dropout_prob)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

# The same tensor is used as both the mean and the covariance input.
output = layer(input_tensor, input_tensor, attention_mask)

mean_hidden_states, cov_hidden_states, attention_probs = (
    torch.round(t.detach() * 1e4) / 1e4 for t in output)

test_eq(mean_hidden_states.shape.numel(), 32)
test_eq(list(mean_hidden_states.shape), [2, 4, 4])

test_eq(cov_hidden_states.shape.numel(), 32)
test_eq(list(cov_hidden_states.shape), [2, 4, 4])

test_eq(attention_probs.shape.numel(), 64)
test_eq(list(attention_probs.shape), [2, 2, 4, 4])

In [None]:
#export
class DistMeanSALayer(nn.Module):
    """Distribution-based layer built on ``DistMeanSelfAttention``.

    The attention output streams go through separate ``DistIntermediate`` blocks;
    the covariance stream is additionally passed through ELU and shifted by +1.

    Args:
        hidden_size: hidden width of both streams.
        num_attention_heads: forwarded to ``DistMeanSelfAttention``.
        hidden_dropout_prob: forwarded to the attention and both feed-forward blocks.
        attention_probs_dropout_prob: forwarded to ``DistMeanSelfAttention``.
    """
    def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,
                 attention_probs_dropout_prob):
        super().__init__()
        self.attention = DistMeanSelfAttention(
            hidden_size, num_attention_heads,
            hidden_dropout_prob, attention_probs_dropout_prob)
        self.mean_intermediate = DistIntermediate(hidden_size, hidden_dropout_prob)
        self.cov_intermediate = DistIntermediate(hidden_size, hidden_dropout_prob)
        self.activation_func = nn.ELU()

    def forward(self, mean_hidden_states, cov_hidden_states, attention_mask):
        """Return ``(mean_output, cov_output, attention_scores)``."""
        mean_att, cov_att, attention_scores = self.attention(
            mean_hidden_states, cov_hidden_states, attention_mask)
        mean_out = self.mean_intermediate(mean_att)
        # ELU >= -1 (default alpha=1), so adding 1 keeps the covariance output non-negative.
        shifted_cov = self.activation_func(self.cov_intermediate(cov_att)) + 1
        return mean_out, shifted_cov, attention_scores

In [None]:
# DistMeanSALayer: both streams keep the input shape; attention scores come back
# as (2, 2, 4, 4) for num_attention_heads=2.
hidden_size, num_attention_heads = 4, 2
hidden_dropout_prob = attention_probs_dropout_prob = 0.2

layer = DistMeanSALayer(hidden_size, num_attention_heads, hidden_dropout_prob,
                        attention_probs_dropout_prob)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

# The same tensor is used as both the mean and the covariance input.
output = layer(input_tensor, input_tensor, attention_mask)

mean_hidden_states, cov_hidden_states, attention_probs = (
    torch.round(t.detach() * 1e4) / 1e4 for t in output)

test_eq(mean_hidden_states.shape.numel(), 32)
test_eq(list(mean_hidden_states.shape), [2, 4, 4])

test_eq(cov_hidden_states.shape.numel(), 32)
test_eq(list(cov_hidden_states.shape), [2, 4, 4])

test_eq(attention_probs.shape.numel(), 64)
test_eq(list(attention_probs.shape), [2, 2, 4, 4])

In [None]:
#export
class DistSAEncoder(nn.Module):
    """Stack of ``num_hidden_layers`` independent copies of :class:`DistLayer`.

    Each layer consumes the mean and covariance outputs of the previous layer.

    Args:
        hidden_size: hidden width of both streams.
        num_attention_heads: forwarded to each ``DistLayer``.
        hidden_dropout_prob: forwarded to each ``DistLayer``.
        attention_probs_dropout_prob: forwarded to each ``DistLayer``.
        num_hidden_layers: number of stacked layers.
        distance_metric: forwarded to each ``DistLayer`` (default ``'wasserstein'``).
    """
    def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,
                 attention_probs_dropout_prob, num_hidden_layers,
                 distance_metric='wasserstein'):
        super().__init__()
        layer = DistLayer(hidden_size, num_attention_heads, hidden_dropout_prob,
                          attention_probs_dropout_prob, distance_metric)
        # deepcopy gives every layer its own parameters, all starting from the same initial values.
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(num_hidden_layers)])

    def forward(self, mean_hidden_states, cov_hidden_states, attention_mask, output_all_encoded_layers=True):
        """Run the stack.

        Returns:
            A list of ``[mean_hidden_states, cov_hidden_states, att_scores]`` entries:
            one per layer when ``output_all_encoded_layers`` is True, otherwise a
            single entry for the last layer. With an empty stack and
            ``output_all_encoded_layers=False`` the inputs are returned with
            ``att_scores=None``.
        """
        all_encoder_layers = []
        att_scores = None  # stays None only when the stack is empty
        for layer_module in self.layer:
            # Rebind BOTH streams so each layer consumes the previous layer's outputs.
            mean_hidden_states, cov_hidden_states, att_scores = layer_module(
                mean_hidden_states, cov_hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append([mean_hidden_states, cov_hidden_states, att_scores])
        if not output_all_encoded_layers:
            all_encoder_layers.append([mean_hidden_states, cov_hidden_states, att_scores])
        return all_encoder_layers

In [None]:
# A two-layer DistSAEncoder returns one [mean, cov, attention_scores] entry per layer.
hidden_size, num_attention_heads, num_hidden_layers = 4, 2, 2
hidden_dropout_prob = attention_probs_dropout_prob = 0.2

layer = DistSAEncoder(hidden_size, num_attention_heads, hidden_dropout_prob,
                      attention_probs_dropout_prob, num_hidden_layers)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

output = layer(input_tensor, input_tensor, attention_mask)
output_shapes = [list(t.shape) for per_layer in output for t in per_layer]

expected_shapes = [[2, 4, 4], [2, 4, 4], [2, 2, 4, 4]] * num_hidden_layers

test_eq(output_shapes, expected_shapes)

In [None]:
#export
class DistMeanSAEncoder(nn.Module):
    """Stack of ``num_hidden_layers`` independent copies of :class:`DistMeanSALayer`.

    Each layer consumes the mean and covariance outputs of the previous layer.

    Args:
        hidden_size: hidden width of both streams.
        num_attention_heads: forwarded to each ``DistMeanSALayer``.
        hidden_dropout_prob: forwarded to each ``DistMeanSALayer``.
        attention_probs_dropout_prob: forwarded to each ``DistMeanSALayer``.
        num_hidden_layers: number of stacked layers.
    """
    def __init__(self, hidden_size, num_attention_heads, hidden_dropout_prob,
                 attention_probs_dropout_prob, num_hidden_layers):
        super().__init__()
        layer = DistMeanSALayer(hidden_size, num_attention_heads, hidden_dropout_prob,
                                attention_probs_dropout_prob)
        # deepcopy gives every layer its own parameters, all starting from the same initial values.
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(num_hidden_layers)])

    def forward(self, mean_hidden_states, cov_hidden_states, attention_mask, output_all_encoded_layers=True):
        """Run the stack.

        Returns:
            A list of ``[mean_hidden_states, cov_hidden_states, att_scores]`` entries:
            one per layer when ``output_all_encoded_layers`` is True, otherwise a
            single entry for the last layer. With an empty stack and
            ``output_all_encoded_layers=False`` the inputs are returned with
            ``att_scores=None``.
        """
        all_encoder_layers = []
        att_scores = None  # stays None only when the stack is empty
        for layer_module in self.layer:
            # Rebind BOTH streams so each layer consumes the previous layer's outputs.
            mean_hidden_states, cov_hidden_states, att_scores = layer_module(
                mean_hidden_states, cov_hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append([mean_hidden_states, cov_hidden_states, att_scores])
        if not output_all_encoded_layers:
            all_encoder_layers.append([mean_hidden_states, cov_hidden_states, att_scores])
        return all_encoder_layers

In [None]:
# A two-layer DistMeanSAEncoder returns one [mean, cov, attention_scores] entry per layer.
hidden_size, num_attention_heads, num_hidden_layers = 4, 2, 2
hidden_dropout_prob = attention_probs_dropout_prob = 0.2

layer = DistMeanSAEncoder(hidden_size, num_attention_heads, hidden_dropout_prob,
                          attention_probs_dropout_prob, num_hidden_layers)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

output = layer(input_tensor, input_tensor, attention_mask)
output_shapes = [list(t.shape) for per_layer in output for t in per_layer]

expected_shapes = [[2, 4, 4], [2, 4, 4], [2, 2, 4, 4]] * num_hidden_layers

test_eq(output_shapes, expected_shapes)

In [None]:
#export
class Encoder(nn.Module):
    """Stack of ``num_hidden_layers`` independent copies of :class:`Layer`.

    Args:
        hidden_size: hidden width of every layer.
        hidden_act: activation for each layer's ``Intermediate`` (callable or ``ACT2FN`` key).
        num_attention_heads: forwarded to each ``Layer``.
        hidden_dropout_prob: forwarded to each ``Layer``.
        attention_probs_dropout_prob: forwarded to each ``Layer``.
        num_hidden_layers: number of stacked layers.
    """
    def __init__(self, hidden_size, hidden_act, num_attention_heads,
                 hidden_dropout_prob, attention_probs_dropout_prob,
                 num_hidden_layers):
        super().__init__()
        layer = Layer(hidden_size, hidden_act, num_attention_heads,
                      hidden_dropout_prob, attention_probs_dropout_prob)
        # deepcopy gives every layer its own parameters, all starting from the same initial values.
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Run the stack.

        Returns:
            A list of ``[hidden_states, attention_scores]`` pairs: one per layer when
            ``output_all_encoded_layers`` is True, otherwise a single pair for the
            last layer. With an empty stack and ``output_all_encoded_layers=False``
            the input is returned with ``attention_scores=None``.
        """
        all_encoder_layers = []
        # Pre-bind so the final-layer branch below cannot hit an unbound local
        # when the stack is empty.
        attention_scores = None
        for layer_module in self.layer:
            hidden_states, attention_scores = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append([hidden_states, attention_scores])
        if not output_all_encoded_layers:
            all_encoder_layers.append([hidden_states, attention_scores])
        return all_encoder_layers

In [None]:
# A two-layer Encoder returns one [hidden_states, attention_scores] pair per layer.
hidden_size, hidden_act = 4, 'swish'
num_attention_heads, num_hidden_layers = 2, 2
hidden_dropout_prob = attention_probs_dropout_prob = 0.2

layer = Encoder(hidden_size, hidden_act, num_attention_heads, hidden_dropout_prob,
                attention_probs_dropout_prob, num_hidden_layers)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

output = layer(input_tensor, attention_mask)
output_shapes = [list(t.shape) for per_layer in output for t in per_layer]

expected_shapes = [[2, 4, 4], [2, 2, 4, 4]] * num_hidden_layers

test_eq(output_shapes, expected_shapes)

In [None]:
#hide
# Record author, timestamp, machine details and package versions (see output below)
# so results can be tied to the environment that produced them.
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d -p recohut

Author: Sparsh A.

Last updated: 2022-01-22 16:49:16

recohut: 0.0.11

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

torch    : 1.10.0+cu111
IPython  : 5.5.0
watermark: 2.3.0

