This Colab notebook reproduces the machine translation experiment from [Beyond Fully Connected Layers with Quaternions: Parametrization of Hypercomplex Multiplications with 1/n Parameters](https://openreview.net/pdf?id=rcQdycl0zyk). We translate German to English using a PHM Transformer. A detailed explanation is available in our [GitHub repo](https://github.com/sinankalkan/CENG501-Spring2021/tree/main/project_BarutcuDemir).

This notebook is less thoroughly annotated and less polished than the style transfer notebook. Most of the code is the same; only the data download, the German tokenization, and the training loop differ. For the other parts, please refer to that notebook, which is available in the GitHub repo linked above.

Parts of the code are adapted from the PyTorch source code and [this tutorial](https://pytorch.org/tutorials/beginner/translation_transformer.html).

# Install Libraries

In [None]:
!pip install torchinfo wandb

Collecting torchinfo
  Downloading torchinfo-1.5.2-py3-none-any.whl (18 kB)
Collecting wandb
  Downloading wandb-0.11.0-py2.py3-none-any.whl (1.8 MB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.20-py3-none-any.whl (178 kB)
Collecting configparser>=3.8.1
  Downloading configparser-5.0.2-py3-none-any.whl (19 kB)
Collecting subprocess32>=3.5.3
  Downloading subprocess32-3.5.4.tar.gz (97 kB)
Collecting sentry-sdk>=0.4.0
  Downloading sentry_sdk-1.3.1-py2.py3-none-any.whl (133 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting urllib3>=1.26.5
  Downloading urllib3-1.26.6-py2.py3-none-any.whl (138 kB)

In [None]:
!pip install torch==1.8.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.8.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torch-1.8.0%2Bcu111-cp37-cp37m-linux_x86_64.whl (1982.2 MB)

In [None]:
!pip install torchtext==0.9.0

Collecting torchtext==0.9.0
  Downloading torchtext-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (7.1 MB)
Installing collected packages: torchtext
  Attempting uninstall: torchtext
    Found existing installation: torchtext 0.10.0
    Uninstalling torchtext-0.10.0:
      Successfully uninstalled torchtext-0.10.0
Successfully installed torchtext-0.9.0


In [None]:
!nvidia-smi

Wed Jul 28 07:36:55 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   67C    P8    12W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Data

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!wget http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz

--2021-07-24 07:31:45--  http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz
Resolving www.statmt.org (www.statmt.org)... 129.215.197.184
Connecting to www.statmt.org (www.statmt.org)|129.215.197.184|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 657632379 (627M) [application/x-gzip]
Saving to: ‘training-parallel-europarl-v7.tgz’

       training-par   0%[                    ] 898.38K  1.07MB/s               ^C


In [None]:
!tar -xzvf "/content/training-parallel-europarl-v7.tgz" -C "/content/" 

training/europarl-v7.cs-en.cs
training/europarl-v7.cs-en.en
training/europarl-v7.de-en.de
training/europarl-v7.de-en.en
training/europarl-v7.es-en.en
training/europarl-v7.es-en.es
training/europarl-v7.fr-en.en
training/europarl-v7.fr-en.fr


German tokenizer

In [None]:
!wget https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz

--2021-07-28 07:37:55--  https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-releases.githubusercontent.com/84940268/920ce580-06fe-11ea-9159-a0235b43e9b7?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20210728%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20210728T073755Z&X-Amz-Expires=300&X-Amz-Signature=0086013d32d81526c15ca41e208725a0b4fbef48386877b2c573eed15d27d209&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=84940268&response-content-disposition=attachment%3B%20filename%3Dde_core_news_sm-2.2.5.tar.gz&response-content-type=application%2Foctet-stream [following]
--2021-07-28 07:37:55--  https://github-releases.githubusercontent.com/84940268/920ce580-06fe-11ea-9159-a0235b43e9b7?X-Amz-Algorithm=AWS4-HMAC-

In [None]:
!tar -xzvf '/content/de_core_news_sm-2.2.5.tar.gz'

de_core_news_sm-2.2.5/
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/PKG-INFO
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/dependency_links.txt
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/top_level.txt
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/SOURCES.txt
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/requires.txt
de_core_news_sm-2.2.5/de_core_news_sm.egg-info/not-zip-safe
de_core_news_sm-2.2.5/setup.cfg
de_core_news_sm-2.2.5/PKG-INFO
de_core_news_sm-2.2.5/MANIFEST.in
de_core_news_sm-2.2.5/meta.json
de_core_news_sm-2.2.5/setup.py
de_core_news_sm-2.2.5/de_core_news_sm/
de_core_news_sm-2.2.5/de_core_news_sm/__init__.py
de_core_news_sm-2.2.5/de_core_news_sm/meta.json
de_core_news_sm-2.2.5/de_core_news_sm/de_core_news_sm-2.2.5/
de_core_news_sm-2.2.5/de_core_news_sm/de_core_news_sm-2.2.5/ner/
de_core_news_sm-2.2.5/de_core_news_sm/de_core_news_sm-2.2.5/ner/model
de_core_news_sm-2.2.5/de_core_news_sm/de_core_news_sm-2.2.5/ner/moves
de

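Download the [dev data](http://www.statmt.org/wmt14/dev.tgz) and make it available to the notebook, then replace the placeholder path in the next cell with the location of the archive.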
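
Because `/path to dev/` in the next cell is only a placeholder, it may help to see the same extraction step written in Python. The following is a minimal sketch that uses only the standard library. The helper name `extract_tgz` and the paths in the commented usage are assumptions for illustration, not part of the original notebook.

```python
import os
import tarfile


def extract_tgz(archive_path, dest_dir):
    """Extract a .tgz archive into dest_dir and return the member names."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(archive_path, mode="r:gz") as archive:
        names = archive.getnames()
        archive.extractall(dest_dir)
    return names


# Hypothetical usage; replace the path with wherever you saved dev.tgz:
# names = extract_tgz("/content/dev.tgz", "/content/")
# print([name for name in names if "newstest2010" in name])
```

Returning the member names makes it easy to confirm that the validation files used later (`newstest2010.de` and `newstest2010.en`) are actually present.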

In [None]:
!tar -xzvf "/path to dev/" -C "/content/" 

dev/
dev/newstest2009-src.es.sgm
dev/newstest2009.cz
dev/newstest2013-ref.en.sgm
dev/news-test2008.es
dev/newstest2010-ref.de.sgm
dev/newstest2010-ref.cz.sgm
dev/newssyscomb2009.cs
dev/newstest2013-src.en.sgm
dev/newstest2011-src.de.sgm
dev/news-test2008-ref.de.sgm
dev/newsdev2014-ref.hi.sgm
dev/newstest2011.de
dev/newstest2010-ref.en.sgm
dev/newstest2009-src.cz.sgm
dev/newstest2011-ref.en.sgm
dev/newssyscomb2009-ref.es.sgm
dev/newstest2009.en
dev/newssyscomb2009-ref.cs.sgm
dev/newssyscomb2009.en
dev/newstest2011-src.en.sgm
dev/newstest2010-src.es.sgm
dev/news-test2008-src.es.sgm
dev/newstest2009-ref.it.sgm
dev/newstest2009-ref.fr.sgm
dev/newstest2012.es
dev/newstest2012-ref.es.sgm
dev/newstest2010.de
dev/newstest2012-src.es.sgm
dev/newssyscomb2009-src.en.sgm
dev/newstest2010.en
dev/news-test2008-ref.cs.sgm
dev/newstest2011.fr
dev/newstest2009-src.fr.sgm
dev/newstest2011-ref.es.sgm
dev/newstest2009-src.hu.sgm
dev/newssyscomb2009-src.es.sgm
dev/newstest2011-src.fr.sgm
dev/newssyscomb200

In [None]:
import math
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
from torchtext.utils import download_from_url, extract_archive
from torch import Tensor
import io
import time

torch.manual_seed(0)

<torch._C.Generator at 0x7fea24f8cc90>

Load the German tokenizer

In [None]:
import spacy
spacy.load('/content/de_core_news_sm-2.2.5/de_core_news_sm/de_core_news_sm-2.2.5')


<spacy.lang.de.German at 0x7fea239b2a90>

Make sure the file paths in this cell are correct before running it. Building the vocabularies takes a long time.

In [None]:
train_filepaths = ['/path to /europarl-v7.de-en.de', '/path to /europarl-v7.de-en.en']
val_filepaths = ['/path to /newstest2010.de', '/path to /newstest2010.en']

en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
de_tokenizer = get_tokenizer('spacy', language='/content/de_core_news_sm-2.2.5/de_core_news_sm/de_core_news_sm-2.2.5')

def build_vocab(filepath, tokenizer):
  counter = Counter()
  with io.open(filepath, encoding="utf8") as f:
    for string_ in f:
      counter.update(tokenizer(string_))
  return Vocab(counter, max_size=32764, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

german_vocab = build_vocab(train_filepaths[0], de_tokenizer)
english_vocab = build_vocab(train_filepaths[1], en_tokenizer)
# Takes about 10 minutes
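
To make the vocabulary step concrete, here is a small pure-Python sketch of what `build_vocab` does conceptually: count tokens, keep the most frequent ones, and place the special tokens first. It uses a whitespace split as a stand-in for the spaCy tokenizers and a plain dict instead of torchtext's `Vocab`. The names `build_toy_vocab` and `lookup` and the sample sentences are hypothetical, and the sketch assumes that `max_size` counts only the non-special tokens.

```python
from collections import Counter

SPECIALS = ['<unk>', '<pad>', '<bos>', '<eos>']


def build_toy_vocab(lines, max_size):
    """Map tokens to indices: specials first, then tokens by descending frequency."""
    counter = Counter()
    for line in lines:
        counter.update(line.split())  # whitespace split stands in for spaCy
    # Sort by frequency (descending); ties are broken alphabetically for determinism.
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    tokens = SPECIALS + [tok for tok, _ in ordered[:max_size]]
    return {tok: idx for idx, tok in enumerate(tokens)}


def lookup(vocab, token):
    """Tokens outside the vocabulary fall back to the index of '<unk>'."""
    return vocab.get(token, vocab['<unk>'])


toy_vocab = build_toy_vocab(["das ist gut", "das ist ein Haus"], max_size=3)
print(toy_vocab)
print(lookup(toy_vocab, "Katze"))
```

Placing the specials first is what lets `<pad>`, `<bos>`, and `<eos>` share the same indices in the German and English vocabularies.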

Tokenizing the full corpus and converting it to index tensors also takes a long time.

In [None]:
def data_process(filepaths):
  # Turn each aligned (German, English) sentence pair into a pair of index tensors.
  data = []
  # Open both files in a with-block so the handles are closed when processing ends.
  with io.open(filepaths[0], encoding="utf8") as de_file, \
       io.open(filepaths[1], encoding="utf8") as en_file:
    for (raw_de, raw_en) in zip(de_file, en_file):
      de_tensor_ = torch.tensor([german_vocab[token] for token in de_tokenizer(raw_de.rstrip("\n"))],
                              dtype=torch.long)
      en_tensor_ = torch.tensor([english_vocab[token] for token in en_tokenizer(raw_en.rstrip("\n"))],
                              dtype=torch.long)
      data.append((de_tensor_, en_tensor_))
  return data

train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 16
PAD_IDX = german_vocab['<pad>']
BOS_IDX = german_vocab['<bos>']
EOS_IDX = german_vocab['<eos>']

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
  de_batch, en_batch = [], []
  for (de_item, en_item) in data_batch:
    de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
    en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
  de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
  en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
  return de_batch, en_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
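
The collate function above wraps every sentence in `<bos>` and `<eos>` and pads all sentences in a batch to the length of the longest one. By default `pad_sequence` returns tensors of shape (max_len, batch). The sketch below mimics that layout with plain Python lists so it is easy to inspect. The index values and the helper name `toy_collate` are assumptions for illustration.

```python
PAD, BOS, EOS = 1, 2, 3  # assumed indices, following the order of the specials list


def toy_collate(sequences):
    """Add BOS/EOS to each sequence, pad to the longest, return time-major rows."""
    wrapped = [[BOS] + seq + [EOS] for seq in sequences]
    max_len = max(len(seq) for seq in wrapped)
    padded = [seq + [PAD] * (max_len - len(seq)) for seq in wrapped]
    # Transpose from (batch, max_len) to (max_len, batch), like pad_sequence's default.
    return [list(step) for step in zip(*padded)]


batch = toy_collate([[10, 11, 12], [20]])
for row in batch:
    print(row)
```

Each printed row is one time step across the batch, which is the layout the sequence-first Transformer modules expect.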

# The PHM Transformer
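
For reference, the PHM layer from the paper linked at the top replaces the dense weight of a fully connected layer with a sum of Kronecker products. In the notation of the `PHMLayer` code in this section (`a` holds the $A_i$, `s` holds the $S_i$, and $n$ is the PHM hyperparameter), this can be written as follows. The symbols $d_{in}$ and $d_{out}$ are our shorthand for `in_features` and `out_features`.

```latex
y = W x + b, \qquad
W = \sum_{i=1}^{n} A_i \otimes S_i, \qquad
A_i \in \mathbb{R}^{n \times n}, \quad
S_i \in \mathbb{R}^{\frac{d_{out}}{n} \times \frac{d_{in}}{n}}, \quad
W \in \mathbb{R}^{d_{out} \times d_{in}}
```

The weight therefore needs $n^3 + \frac{d_{out}\, d_{in}}{n}$ parameters instead of $d_{out}\, d_{in}$, which is close to a factor of $1/n$ whenever $d_{out}\, d_{in}$ is much larger than $n^4$.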

In [None]:
from torch.nn.parameter import Parameter, UninitializedParameter
import torch
from torch import nn
from torch import Tensor
from torch.nn import init
import math
from torch.nn import functional as F

class PHMLayer(nn.Module):

  def __init__(self, n, in_features, out_features):
    super(PHMLayer, self).__init__()
    self.n = n
    self.in_features = in_features
    self.out_features = out_features

    self.bias = Parameter(torch.Tensor(out_features))

    self.a = torch.zeros((n, n, n))
    self.a = Parameter(torch.nn.init.xavier_uniform_(self.a))

    self.s = torch.zeros((n, self.out_features//n, self.in_features//n)) 
    self.s = Parameter(torch.nn.init.xavier_uniform_(self.s))

    self.weight = torch.zeros((self.out_features, self.in_features))

    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
    bound = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -bound, bound)

  def kronecker_product1(self, a, b):
    siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    siz0 = res.shape[:-4]
    out = res.reshape(siz0 + siz1)
    return out

  def forward(self, input: Tensor) -> Tensor:
    self.weight = torch.sum(self.kronecker_product1(self.a, self.s), dim=0)
    input = input.type(dtype=self.weight.type())
    return F.linear(input, weight=self.weight, bias=self.bias)

  def extra_repr(self) -> str:
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None)
    
  def reset_parameters(self) -> None:
    init.kaiming_uniform_(self.a, a=math.sqrt(5))
    init.kaiming_uniform_(self.s, a=math.sqrt(5))
    # self.weight only supplies the (out_features, in_features) shape needed for fan_in.
    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
    bound = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -bound, bound)
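
As a quick sanity check on the "1/n parameters" claim, the sketch below counts the trainable parameters of `PHMLayer` (the `a` and `s` factors plus the bias) and compares them with a dense linear layer of the same size. The function names and the example size of 512 are assumptions chosen for illustration.

```python
def phm_param_count(n, in_features, out_features):
    """Parameters of a PHM layer: n*n*n for a, n*(out/n)*(in/n) for s, plus the bias."""
    a_params = n * n * n
    s_params = n * (out_features // n) * (in_features // n)
    return a_params + s_params + out_features


def dense_param_count(in_features, out_features):
    """Parameters of a dense linear layer: weight matrix plus bias."""
    return in_features * out_features + out_features


for n in (2, 4, 8):
    phm = phm_param_count(n, 512, 512)
    dense = dense_param_count(512, 512)
    print(n, phm, dense, round(phm / dense, 3))
```

The ratio is slightly above 1/n because the bias and the `a` factors are not reduced.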

In [None]:
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
#from torch.nn import NonDynamicallyQuantizableLinear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn import functional as F

In [None]:
from torch.nn.functional import *

In [None]:
import copy
from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn import Module
from torch.nn import MultiheadAttention
from torch.nn import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn import Dropout
from torch.nn import Linear
from torch.nn import LayerNorm

In [None]:
def phm_multi_head_attention_forward(
    phm: int,
    a: Tensor,
    s: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
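    Args:
        phm: the PHM dimension n. It is not used directly inside this function.
        a, s: PHM factors. The output projection weight is computed as the sum over
            dim=0 of the Kronecker products of ``a`` and ``s``.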
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.


    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    tens_ops = (a, s, query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        # Dispatch to this function (not the stock one) and forward the PHM arguments too.
        return handle_torch_function(
            phm_multi_head_attention_forward,
            tens_ops,
            phm,
            a,
            s,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
        )
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif key is value or torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32
            or attn_mask.dtype == torch.float64
            or attn_mask.dtype == torch.float16
            or attn_mask.dtype == torch.uint8
            or attn_mask.dtype == torch.bool
        ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float("-inf"),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

    # The stock output projection is replaced by a PHM projection:
    # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    def kronecker_product1(a, b):
        # Batched Kronecker product over the last two dimensions of a and b.
        siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        out = res.reshape(siz0 + siz1)
        return out

    # Build the output projection weight as the sum of the Kronecker products.
    out_proj_weight2 = torch.sum(kronecker_product1(a, s), dim=0)
    attn_output = linear(attn_output, out_proj_weight2, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
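
The masking logic in the function above relies on a property of softmax: a score of negative infinity becomes an attention weight of exactly zero. The following sketch illustrates this for a single query position in plain Python. `masked_softmax` is a hypothetical helper written for this illustration, not part of PyTorch.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over scores where positions with mask=True are set to -inf first."""
    filled = [float("-inf") if m else s for s, m in zip(scores, mask)]
    peak = max(filled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]


# Three key positions; the last one is padding and must receive no attention.
weights = masked_softmax([2.0, 1.0, 0.5], [False, False, True])
print(weights)
```

The sketch assumes at least one unmasked position. If every position is masked, the normalization is undefined, which is worth keeping in mind when building padding masks.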

In [None]:
class PHMMultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

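    Args:
        phm: the PHM dimension n. embed_dim should be divisible by phm so that
            the Kronecker factors reconstruct full-size weights.
        embed_dim: total dimension of the model.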
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> multihead_attn = PHMMultiheadAttention(phm, embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, phm, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(PHMMultiheadAttention, self).__init__()
        self.phm = phm
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            #self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            #self.in_proj_weight = PHMLayer(2, 3 * embed_dim, embed_dim)
            
            #self.in_proj_weight = torch.zeros((embed_dim, 3 * embed_dim)) #out in

            self.a = Parameter(torch.nn.init.xavier_uniform_(torch.zeros((phm, phm, phm))))

            self.s = Parameter(torch.nn.init.xavier_uniform_(torch.zeros((phm, (3*embed_dim)//phm, embed_dim//phm))))
            
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        
        #self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.out_proj = PHMLayer(phm, embed_dim, embed_dim)


        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            #xavier_uniform_(self.in_proj_weight)
            pass
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def kronecker_product1(self, a, b):
        siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        out = res.reshape(siz0 + siz1)
        return out

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(PHMMultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None):
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shapes for inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero
          positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
          source sequence length.

          If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
          length, S is the source sequence length. ``attn_mask`` ensures that position i is only allowed to attend
          to the unmasked positions. If a ByteTensor is provided, the non-zero positions cannot be attended to
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          cannot be attended to while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

    Shapes for outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
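        if self._qkv_same_embed_dim:
            # Rebuild the fused q/k/v projection as a sum of Kronecker products of `a` and `s`.
            # (`a` and `s` only exist when query, key, and value share embed_dim.)
            self.in_proj_weight = torch.sum(self.kronecker_product1(self.a, self.s), dim=0)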

        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = phm_multi_head_attention_forward(
                self.phm, self.out_proj.a, self.out_proj.s, query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                )
        else:
            attn_output, attn_output_weights = phm_multi_head_attention_forward(
                self.phm, self.out_proj.a, self.out_proj.s, query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
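
The fused query/key/value projection in `PHMMultiheadAttention` is never stored as a dense matrix. It is rebuilt on every forward pass as a sum of `phm` Kronecker products between the small matrices in `self.a` and the blocks in `self.s`. The sketch below reproduces that construction with plain Python lists so the shapes and the parameter saving can be checked by hand. The names `kron`, `phm_weight`, and `param_counts` are illustrative and are not part of the notebook.

```python
def kron(a, b):
    """Kronecker product of two matrices given as lists of rows."""
    rows_b, cols_b = len(b), len(b[0])
    return [
        [a[i // rows_b][j // cols_b] * b[i % rows_b][j % cols_b]
         for j in range(len(a[0]) * cols_b)]
        for i in range(len(a) * rows_b)
    ]


def phm_weight(a_blocks, s_blocks):
    """Sum of Kronecker products, mirroring torch.sum(kron(a, s), dim=0)."""
    total = kron(a_blocks[0], s_blocks[0])
    for a, s in zip(a_blocks[1:], s_blocks[1:]):
        term = kron(a, s)
        total = [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(total, term)]
    return total


def param_counts(n, out_features, in_features):
    """Return (PHM parameter count, dense parameter count), biases excluded."""
    phm = n ** 3 + n * (out_features // n) * (in_features // n)
    dense = out_features * in_features
    return phm, dense


# Toy case with n = 2: two 2x2 "a" matrices and two 3x1 "s" blocks give a 6x2 weight.
a_blocks = [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]
s_blocks = [[[1], [2], [3]], [[10], [20], [30]]]
w = phm_weight(a_blocks, s_blocks)
print(len(w), len(w[0]))

# Fused q/k/v projection with the notebook's settings: embed_dim = 256, n = 4.
print(param_counts(4, 3 * 256, 256))
```

With `embed_dim = 256` and `n = 4`, the fused projection needs 49,216 weights instead of 196,608, which is close to the 1/n ratio the paper's title refers to.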

In [None]:
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

class PHMTransformerEncoderLayer(Module):

    __constants__ = ['batch_first']

    def __init__(self, phm, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(PHMTransformerEncoderLayer, self).__init__()
        self.self_attn = PHMMultiheadAttention(phm, d_model, nhead, dropout=dropout)
        self.linear1 = PHMLayer(phm, d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = PHMLayer(phm, dim_feedforward, d_model)
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(PHMTransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
 
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

In [None]:
class PHMTransformerDecoderLayer(Module):

    __constants__ = ['batch_first']

    def __init__(self, phm, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(PHMTransformerDecoderLayer, self).__init__()
        self.self_attn = PHMMultiheadAttention(phm, d_model, nhead, dropout=dropout)
        self.multihead_attn = PHMMultiheadAttention(phm, d_model, nhead, dropout=dropout)
        self.linear1 = PHMLayer(phm, d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = PHMLayer(phm, dim_feedforward, d_model)
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(PHMTransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:

        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

In [None]:
from torch.nn import TransformerEncoder, TransformerDecoder

class PHMTransformer(Module):
    def __init__(self, phm, nhead, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(PHMTransformer, self).__init__()
        encoder_layer = PHMTransformerEncoderLayer(phm, d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = PHMTransformerDecoderLayer(phm, d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
                
        self.generator = PHMLayer(phm, emb_size, tgt_vocab_size)  # change this one too
        self.src_tok_emb = PHMTokenEmbedding(src_vocab_size, emb_size, phm)
        self.tgt_tok_emb = PHMTokenEmbedding(tgt_vocab_size, emb_size, phm)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [None]:
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, nhead, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
                
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [None]:
class PHMEmbedding(Module):
   
    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
                     'norm_type', 'scale_grad_by_freq', 'sparse']

    def __init__(self, num_embeddings: int, embedding_dim: int, phm: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(PHMEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.a = Parameter(torch.nn.init.xavier_uniform_(torch.zeros((phm, phm, phm))))

            self.s = Parameter(torch.nn.init.xavier_uniform_(torch.zeros((phm, num_embeddings//phm, embedding_dim//phm))))

            self.weight = torch.zeros((num_embeddings, embedding_dim))
            #self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
            #self.reset_parameters()
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)

        self.sparse = sparse

    def reset_parameters(self) -> None:
        #init.normal_(self.weight)
        #self._fill_padding_idx_with_zero()
        pass

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def kronecker_product1(self, a, b):
        siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        out = res.reshape(siz0 + siz1)
        return out

    def forward(self, input: Tensor) -> Tensor:
        self.weight = torch.sum(self.kronecker_product1(self.a, self.s), dim=0)
        #self._fill_padding_idx_with_zero()
        return F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

    def extra_repr(self) -> str:
        s = '{num_embeddings}, {embedding_dim}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        return s.format(**self.__dict__)
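
`PHMEmbedding` builds its table from blocks of shape `(phm, num_embeddings // phm, embedding_dim // phm)`, so the reconstructed weight has `phm * (num_embeddings // phm)` rows. When the vocabulary size is not a multiple of `phm`, the integer division rounds down and the highest token ids have no row to look up. Below is a minimal sketch of one way to guard against this, assuming it is acceptable to pad the vocabulary size up to the next multiple of `phm`. `rows_covered` and `pad_to_multiple` are hypothetical helpers, not part of the notebook, and the vocabulary size is made up.

```python
def rows_covered(num_embeddings: int, phm: int) -> int:
    """Rows in the weight rebuilt from (phm, num_embeddings // phm, ...) blocks."""
    return phm * (num_embeddings // phm)


def pad_to_multiple(size: int, phm: int) -> int:
    """Smallest multiple of phm that is >= size."""
    return -(-size // phm) * phm


vocab_size = 10_003  # hypothetical vocabulary size, not taken from the notebook
phm = 4
print(rows_covered(vocab_size, phm))      # rows actually available
padded = pad_to_multiple(vocab_size, phm)
print(padded, rows_covered(padded, phm))  # padded size covers every token id
```

Passing the padded size as `num_embeddings` would leave a few unused rows at the end of the table, which is usually harmless.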

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + 
                            self.pos_embedding[:token_embedding.size(0),:])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

class PHMTokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size, phm):
        super(PHMTokenEmbedding, self).__init__()
        self.embedding = PHMEmbedding(vocab_size, emb_size, phm)
        self.emb_size = emb_size
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
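
`PositionalEncoding` precomputes a sinusoidal table: even feature indices hold a sine and odd indices a cosine of the position scaled by a frequency that decays with the feature index. The following stdlib-only sketch computes one row of that table so the formula can be inspected without tensors. `sinusoidal_row` is an illustrative name, not a function from the notebook.

```python
import math


def sinusoidal_row(pos: int, emb_size: int) -> list:
    """Positional encoding for one position: sin on even indices, cos on odd."""
    row = [0.0] * emb_size
    for k in range(0, emb_size, 2):
        # Same frequency term as `den` in PositionalEncoding.
        freq = math.exp(-k * math.log(10000) / emb_size)
        row[k] = math.sin(pos * freq)
        if k + 1 < emb_size:
            row[k + 1] = math.cos(pos * freq)
    return row


print(sinusoidal_row(0, 4))
print([round(x, 4) for x in sinusoidal_row(1, 4)])
```

Position 0 always maps to alternating zeros and ones, because sin(0) = 0 and cos(0) = 1 for every frequency.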

In [None]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
  src_seq_len = src.shape[0]
  tgt_seq_len = tgt.shape[0]

  tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
  src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

  src_padding_mask = (src == PAD_IDX).transpose(0, 1)
  tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
  return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
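
The two kinds of mask built above serve different purposes: the square target mask is additive (0.0 where attention is allowed, `-inf` where a position would look at a later token), while the padding masks are boolean and mark tokens to ignore. A small list-based sketch of both, assuming a batch laid out as one list per sequence and a made-up padding index of 1:

```python
NEG_INF = float("-inf")


def causal_mask(sz: int) -> list:
    """Additive mask: 0.0 where position i may attend to j (j <= i), -inf elsewhere."""
    return [[0.0 if j <= i else NEG_INF for j in range(sz)] for i in range(sz)]


def padding_mask(batch: list, pad_idx: int) -> list:
    """Boolean mask per sequence: True marks padding tokens to be ignored."""
    return [[tok == pad_idx for tok in seq] for seq in batch]


for row in causal_mask(3):
    print(row)
print(padding_mask([[5, 7, 1], [9, 1, 1]], pad_idx=1))
```

The list layout here is (batch, sequence), which corresponds to the shape of `src_padding_mask` after the `transpose(0, 1)` in `create_mask`.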

# Instantiate Model

In [None]:
SRC_VOCAB_SIZE = len(german_vocab)
TGT_VOCAB_SIZE = len(english_vocab)
EMB_SIZE = 256
NHEAD = 4
FFN_HID_DIM = 256
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
N = 4  # PHM hyperparameter n: each PHM weight uses roughly 1/n of the dense parameters

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
"""
transformer = Seq2SeqTransformer(NHEAD, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, 
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)
"""
transformer = PHMTransformer(N, NHEAD, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, 
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)


for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(device)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.AdamW(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.05
)

In [None]:
print(transformer)

PHMTransformer(
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): PHMTransformerEncoderLayer(
        (self_attn): PHMMultiheadAttention(
          (out_proj): PHMLayer(in_features=256, out_features=256, bias=True)
        )
        (linear1): PHMLayer(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): PHMLayer(in_features=256, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): PHMTransformerEncoderLayer(
        (self_attn): PHMMultiheadAttention(
          (out_proj): PHMLayer(in_features=256, out_features=256, bias=True)
        )
        (linear1): PHMLayer(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
   

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
count_parameters(transformer)

6863296

In [None]:
import wandb
wandb.init(project='transformer', entity='demegire')

[34m[1mwandb[0m: Currently logged in as: [33mdemegire[0m (use `wandb login --relogin` to force relogin)


In [None]:
config = wandb.config
config.learning_rate = optimizer.state_dict()['param_groups'][0]['lr']
config.SRC_VOCAB_SIZE = SRC_VOCAB_SIZE
config.TGT_VOCAB_SIZE = TGT_VOCAB_SIZE
config.EMB_SIZE = EMB_SIZE
config.NHEAD = NHEAD
config.FFN_HID_DIM = FFN_HID_DIM
config.BATCH_SIZE = BATCH_SIZE
config.NUM_ENCODER_LAYERS = NUM_ENCODER_LAYERS
config.NUM_DECODER_LAYERS = NUM_DECODER_LAYERS
config.params = count_parameters(transformer)

In [None]:
from torch.nn.functional import softmax

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len-1):
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
          break
    return ys

def beam_search(model, src, src_mask, max_len, start_symbol, beam_size=2, alpha=0.7):
  with torch.no_grad():
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)

    memory = memory.to(device)
    memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
    tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                      .type(torch.bool)).to(device)
    out = model.decode(ys, memory, tgt_mask)
    out = out.transpose(0, 1)
    probs = model.generator(out[:, -1])

    soft_probs = softmax(probs, dim=1)
          
    tops = torch.topk(soft_probs, beam_size, dim = 1)

    sentences = []

    for i in range(beam_size):
      sentences.append((tops.values[0][i].item(), 
                        torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(tops.indices[0][i].item())], dim=0)))
    while True:      
      candidate_sentences = []
      eos_index = []

      for i in range(beam_size):
        eos_flag = False
        if sentences[i][1][-1].item() != EOS_IDX:
          memory = memory.to(device)
          memory_mask = torch.zeros(sentences[i][1].shape[0], memory.shape[0]).to(device).type(torch.bool)
          tgt_mask = (generate_square_subsequent_mask(sentences[i][1].size(0))
                                            .type(torch.bool)).to(device)
          out = model.decode(sentences[i][1], memory, tgt_mask)
          out = out.transpose(0, 1)
          probs = model.generator(out[:, -1])
          soft_probs = softmax(probs, dim=1)

          tops = torch.topk(soft_probs, beam_size, dim = 1)

          for j in range(beam_size):
            candidate_sentences.append((1/(len(torch.cat([sentences[i][1], torch.ones(1, 1).type_as(src.data).fill_(tops.indices[0][j].item())], dim=0))**alpha) * sentences[i][0] * tops.values[0][j].item(),
                                        torch.cat([sentences[i][1], torch.ones(1, 1).type_as(src.data).fill_(tops.indices[0][j].item())], dim=0) ))
        else:
          eos_index.append(i)
          eos_flag = True

      # Carry over every finished hypothesis, not only when the last beam is finished.
      if eos_index:
        for idx in eos_index:
          candidate_sentences.append(sentences[idx])

      candidate_sentences.sort(key=lambda x: x[0], reverse=True)
      sentences = candidate_sentences[:beam_size]
      candidate_sentences = []

      end_flag = True
      max_flag = False

      for idx in range(beam_size):
        if sentences[idx][1][-1].item() != EOS_IDX :
          end_flag = False

      for idx in range(beam_size):
        if len(sentences[idx][1]) == max_len:
          max_flag = True

      if end_flag or max_flag:
        return sorted(sentences,key=lambda x: x[0], reverse=True)[0][1]
      


def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
  model.eval()
  tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer(src)]+ [EOS_IDX]
  num_tokens = len(tokens)
  src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )
  src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
  #tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
  tgt_tokens = beam_search(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX, beam_size=5, alpha=0.6).flatten()
  return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
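
In `beam_search`, each candidate is scored as the previous hypothesis score times the probability of the new token, divided by the candidate length raised to `alpha`. The toy sketch below isolates that scoring rule and one expansion step, using hand-written probabilities in place of the model. `beam_score` and `expand` are illustrative names and the probabilities are made up.

```python
def beam_score(prev_score: float, token_prob: float, length: int, alpha: float) -> float:
    """Score used in the notebook's beam_search: prev * p / length**alpha."""
    return prev_score * token_prob / (length ** alpha)


def expand(beams, next_probs, beam_size, alpha):
    """One expansion step over toy beams.

    beams: list of (score, token_list); next_probs: per-beam dict token -> prob.
    """
    candidates = []
    for (score, tokens), probs in zip(beams, next_probs):
        for tok, p in probs.items():
            new_tokens = tokens + [tok]
            candidates.append((beam_score(score, p, len(new_tokens), alpha), new_tokens))
    # Keep the beam_size highest-scoring candidates.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:beam_size]


beams = [(0.6, ["<bos>", "a"]), (0.4, ["<bos>", "b"])]
next_probs = [{"x": 0.5, "y": 0.5}, {"x": 0.9, "y": 0.1}]
for score, tokens in expand(beams, next_probs, beam_size=2, alpha=0.6):
    print(round(score, 4), tokens)
```

Because the previous score already contains the penalties from earlier steps, the length penalty compounds as a hypothesis grows. A common alternative, not used in this notebook, is to accumulate log-probabilities and apply the length normalization once when ranking.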

In [None]:
translate(transformer, "I have half a mind to hit you before you speak again", german_vocab, english_vocab, de_tokenizer)

' alike issuance cycles cycles usher usher apprentices charity circulating circulating Seppänen Seppänen Seppänen Seppänen Seppänen circulating circulating circulating'

In [None]:
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

smoothie = SmoothingFunction().method4

# Train

In [None]:
with open('/content/dev/newstest2010.de') as f:
    dev_lines_de = f.readlines()

In [None]:
with open('/content/dev/newstest2010.en') as g:
    dev_lines_en = g.readlines()

In [None]:
from nltk.tokenize import RegexpTokenizer

def bleu(model, src_lines, tgt_lines, src_vcb, tgt_vcb, src_tok):

  tokenizer = RegexpTokenizer(r'\w+')
  assert len(src_lines) == len(tgt_lines)
  n = len(src_lines)
  scores = 0
  
  for mo, og in zip(src_lines, tgt_lines):
    reference = [tokenizer.tokenize(og[:-3])]
    candidate = tokenizer.tokenize(translate(model, mo[:-3], src_vcb, tgt_vcb, src_tok))

    try:
      scores += sentence_bleu(reference, candidate, smoothing_function=smoothie)
    except Exception:
      # Skip sentences that sentence_bleu cannot score; they contribute 0 to the average.
      print('except')
  return scores/n
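
`sentence_bleu` combines modified n-gram precisions (n = 1 to 4 with the default weights) with a brevity penalty. The sketch below shows only the first ingredient, the clipped n-gram precision, to make clear why repeating a correct word does not inflate the score. It is an illustration of the idea and not NLTK's implementation; `ngrams` and `clipped_precision` are names chosen for this example.

```python
from collections import Counter


def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def clipped_precision(candidate, reference, n):
    """Modified n-gram precision: candidate counts are clipped by reference counts."""
    cand_counts = Counter(ngrams(candidate, n))
    ref_counts = Counter(ngrams(reference, n))
    total = sum(cand_counts.values())
    if total == 0:
        return 0.0
    matched = sum(min(count, ref_counts[gram]) for gram, count in cand_counts.items())
    return matched / total


candidate = "the the the cat".split()
reference = "the cat sat".split()
print(clipped_precision(candidate, reference, 1))
print(clipped_precision(candidate, reference, 2))
```

Here "the" appears three times in the candidate but only once in the reference, so only one of its occurrences counts as a match.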



Unlike in the style transfer task, validation here runs during an epoch (every 500 batches), so the implementation is slightly different.

In [None]:
def train_epoch(model, train_iter, val_iter, optimizer):
  
  def evaluate(model, val_iter):
    # Validation loss averaged over at most the first 51 batches of val_iter.
    model.eval()
    losses = 0
    with torch.no_grad():  # no gradients are needed for validation
      for idx, (src, tgt) in enumerate(val_iter):
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
        if idx == 50:
          # idx is zero-based, so idx + 1 batches have been accumulated
          return losses / (idx + 1)
    return losses / len(val_iter)



  model.train()
  losses = 0

  for idx, (src, tgt) in enumerate(train_iter):
      src = src.to(device)
      tgt = tgt.to(device)
            
      tgt_input = tgt[:-1, :]

      src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

      logits = model(src, tgt_input, src_mask, tgt_mask,
                                src_padding_mask, tgt_padding_mask, src_padding_mask)
      
      optimizer.zero_grad()
      
      tgt_out = tgt[1:,:]
      loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
      loss.backward()

      optimizer.step()
      losses += loss.item()

      if idx % 50 == 0:
        wandb.log({"train_loss": loss.item()})

      if idx % 500 == 0:
        val_loss = evaluate(model, val_iter)
        print(val_loss)
        wandb.log({"val_loss": val_loss})
        model.train()

      if idx % 10000 == 0:
        print(idx)

      if idx == 50000:
        # Stop early; average over the idx + 1 batches actually processed.
        return losses / (idx + 1)

  return losses / len(train_iter)
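
When a loop stops after a fixed number of batches, the running sum has to be divided by the number of batches actually consumed for the result to be a mean. The sketch below shows this in plain Python, with a list of numbers standing in for per-batch losses. `capped_mean` is a hypothetical helper, not part of this notebook.

```python
def capped_mean(values, max_items):
    """Mean of the first max_items values, or of all values if there are fewer."""
    total, count = 0.0, 0
    for value in values:
        total += value
        count += 1
        if count == max_items:
            break  # stop early; `count` is the number of items consumed
    return total / count if count else 0.0


fake_losses = [float(i) for i in range(100)]  # stand-in for per-batch losses
first_51 = capped_mean(fake_losses, 51)       # averages the values 0..50
short_run = capped_mean([2.0, 4.0], 51)       # fewer items than the cap
print(first_51, short_run)
```

Counting consumed items inside the loop also works for iterators that have no `len()`, which can be the case for streaming data loaders.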



In [None]:
wandb.watch(transformer)
start_time = time.time()
train_loss = train_epoch(transformer, train_iter, valid_iter, optimizer)
end_time = time.time()
print(end_time-start_time)

10.598576469421387
0
7.173872318267822
6.672988996505738
6.474006509780883
6.372186603546143
6.277548971176148
6.192550773620606
6.169077386856079
6.092277603149414
6.064428472518921
6.007866668701172
5.971266860961914
5.887045183181763
5.8723788070678715
5.836114101409912
5.751456031799316
5.7895409297943115
5.706301603317261
5.6919488525390625
5.636869096755982
5.646174039840698
10000
5.6134465789794925
5.5279426097869875
5.494049663543701
5.494959659576416
5.5014401721954345
5.499492645263672
5.432122173309327
5.414782619476318
5.406279373168945
5.407729368209839
5.325137166976929
5.3358434963226316
5.365396203994751
5.2861355113983155
5.289815769195557
5.232405090332032
5.288078517913818
5.195194349288941
5.1787729167938235
5.17140751838684
20000
5.218420381546021
5.187234020233154
5.192825765609741
5.132443227767944
5.147489576339722
5.14576735496521
5.109126100540161
5.150072746276855
5.122694683074951
5.129611721038819
5.0585581398010255
5.127977066040039
5.105482959747315
5.068

# Test

In [None]:
start_time = time.time()
score = bleu(transformer, dev_lines_de[:500], dev_lines_en[:500], german_vocab, english_vocab, de_tokenizer)
end_time = time.time()
print(score)
print(end_time-start_time)

0.17876482114663714
437.6989495754242


In [None]:
wandb.finish()


metric,final value
train_loss,3.36155
_runtime,2653.0
_timestamp,1627463384.0
_step,1101.0
val_loss,4.66185


metric,history
train_loss,█▇▅▅▅▄▄▄▃▄▃▄▃▄▃▄▃▃▂▃▂▂▃▃▃▃▂▂▁▃▃▂▂▂▂▂▂▂▂▁
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
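
The final losses in the run summary can also be read as perplexities. The sketch below assumes that `loss_fn` is a mean token-level cross-entropy in nats, as `torch.nn.CrossEntropyLoss` with its default reduction would give. That is an assumption, since `loss_fn` is defined elsewhere in the notebook. `perplexity` is a hypothetical helper name.

```python
import math


def perplexity(mean_cross_entropy):
    """Perplexity for a mean per-token cross-entropy given in nats."""
    return math.exp(mean_cross_entropy)


# Final values taken from the wandb run summary above.
final_train_ppl = perplexity(3.36155)
final_val_ppl = perplexity(4.66185)
print(f"train perplexity ~ {final_train_ppl:.1f}")
print(f"val perplexity ~ {final_val_ppl:.1f}")
```

Under that assumption, a perplexity of k roughly means the model is as uncertain as if it were choosing uniformly among k tokens at each position, so the gap between the two numbers gives a feel for how much worse the model does on held-out data.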
