<a href="https://colab.research.google.com/github/xSakix/AI_colab_notebooks/blob/master/gpt_2_slovak.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch



In [2]:
# IPython shell capture: run `nvidia-smi` and collect its output lines.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in the command output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Mon Mar  2 18:46:13 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.48.02    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [3]:
# Mount Google Drive at /content/drive; the training cell reads its corpus from
# and writes its checkpoints to paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn.functional as F


def swish(x):
    """Swish activation: the input multiplied element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate


def _gelu_python(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        This is now written in C in torch.nn.functional
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# Use the pure-Python GELU on torch older than 1.4.0, the native F.gelu otherwise.
# NOTE(review): where torch.__version__ is a plain str this is a lexicographic
# string comparison, not a numeric version comparison - confirm for the target torch.
gelu = _gelu_python if torch.__version__ < "1.4.0" else F.gelu


def gelu_new(x):
    """ Tanh approximation of GELU, as currently used in the Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    cubic = x + 0.044715 * torch.pow(x, 3)
    inner = math.sqrt(2 / math.pi) * cubic
    return 0.5 * x * (1 + torch.tanh(inner))


# Maps an activation name (as used in config objects) to the callable implementing it.
ACT2FN = {
    "relu": F.relu,
    "swish": swish,
    "gelu": gelu,
    # torch.tanh computes the same function as the deprecated torch.nn.functional.tanh alias.
    "tanh": torch.tanh,
    "gelu_new": gelu_new,
}


def get_activation(activation_string):
    """Return the activation callable registered in ACT2FN under `activation_string`.

    Raises:
        KeyError: if the name is not a key of ACT2FN.
    """
    if activation_string not in ACT2FN:
        message = "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
            activation_string, list(ACT2FN.keys())
        )
        raise KeyError(message)
    return ACT2FN[activation_string]

In [0]:
import logging
import os
import typing

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

# renamed because it's not Conv1d from PyTorch, but a linear layer with a transposed weight
class NotConv1D(nn.Module):
    """ "Conv1D" layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    It behaves like a Linear layer, except the weight is stored as (nx, nf),
    i.e. transposed with respect to nn.Linear.
    """

    def __init__(self, nf, nx):
        """
        nf: number of output features
        nx: number of input features
        """
        super().__init__()
        self.nf = nf
        # Weight drawn from N(0, 0.02^2); bias starts at zero.
        self.weight = nn.Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Project the last dimension of `x` to `nf` features; leading dimensions are kept."""
        flat = x.view(-1, x.size(-1))  # collapse all leading dims into one for addmm
        projected = torch.addmm(self.bias, flat, self.weight)
        return projected.view(*x.size()[:-1], self.nf)


class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence of hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: name of an activation registered in ACT2FN (e.g. 'tanh') => add that activation
                to the output; None or empty => no activation.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation

        Raises:
            NotImplementedError: at construction time when summary_type == 'attn'.
    """

    def __init__(self, config):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # nn.Identity is the pass-through placeholder for every optional stage;
        # no `Identity` name is defined in this notebook.
        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation = (
            get_activation(activation_string) if activation_string else nn.Identity()
        )  # type: typing.Callable

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token

            Returns the summary after first dropout -> projection -> activation -> last dropout.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default to the last sequence position for every element.
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output

def prune_conv1d_layer(layer, index, dim=1):
    """ Build a pruned copy of a NotConv1D layer that keeps only the entries listed in `index`.

        The weight is sliced along `dim`. The bias is sliced with `index` unless dim == 0,
        in which case it is copied whole.
        The returned layer is a new module on the same device, with requires_grad=True
        on both parameters. Used to remove attention heads.
    """
    device = layer.weight.device
    index = index.to(device)

    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    bias_source = layer.bias if dim == 0 else layer.bias[index]
    kept_bias = bias_source.clone().detach()

    pruned_shape = list(layer.weight.size())
    pruned_shape[dim] = len(index)
    # The weight is stored as (nx, nf) while the constructor takes (nf, nx).
    pruned = NotConv1D(pruned_shape[1], pruned_shape[0]).to(device)

    # Switch gradient tracking off while overwriting the parameters in place, then back on.
    for param, values in ((pruned.weight, kept_weight), (pruned.bias, kept_bias)):
        param.requires_grad = False
        param.copy_(values.contiguous())
        param.requires_grad = True
    return pruned

In [0]:
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# gpt: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf
# transformer block: https://arxiv.org/pdf/1801.10198.pdf

import logging
import math
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

logger = logging.getLogger(__name__)

# memory compressed att? https://arxiv.org/pdf/1801.10198.pdf
class Attention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    A single NotConv1D (`c_attn`) produces query, key and value together and `c_proj`
    projects the merged heads back. Keys are kept transposed internally as
    (batch, head, head_features, seq_length) so the score matmul needs no extra transpose.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # (1, 1, n_ctx, n_ctx) lower-triangular matrix of ones, used as the causal mask.
        causal_mask = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", causal_mask.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = NotConv1D(n_state * 3, nx)
        self.c_proj = NotConv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given head indices by rebuilding c_attn / c_proj without their features."""
        if len(heads) == 0:
            return
        head_dim = self.split_size // self.n_head
        keep = torch.ones(self.n_head, head_dim)
        heads = set(heads) - self.pruned_heads  # convert to set and remove already pruned heads
        for head in heads:
            # Heads pruned earlier shift the position of this head in the current layout.
            shift = sum(1 for earlier in self.pruned_heads if earlier < head)
            keep[head - shift] = 0
        keep = keep.view(-1).contiguous().eq(1)
        index = torch.arange(len(keep))[keep].long()
        # The same feature indices are kept in each of the query / key / value thirds of c_attn.
        index_attn = torch.cat([index + offset for offset in (0, self.split_size, 2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        remaining_heads = self.n_head - len(heads)
        self.split_size = head_dim * remaining_heads
        self.n_head = remaining_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Attention over already-split heads; returns [context] or [context, probabilities]."""
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        n_query, n_key = scores.size(-2), scores.size(-1)
        causal = self.bias[:, :, n_key - n_query : n_key, :n_key]
        # Allowed positions keep their score; masked positions are set to -1e4.
        scores = scores * causal - 1e4 * (1 - causal)

        if attention_mask is not None:
            # Additive mask supplied by the caller.
            scores = scores + attention_mask

        probs = torch.softmax(scores, dim=-1)
        probs = self.attn_dropout(probs)

        # Mask heads if we want to
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, v)
        return [context, probs] if self.output_attentions else [context]

    def merge_heads(self, x):
        """(batch, head, seq_length, head_features) -> (batch, seq_length, head * head_features)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        merged_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*merged_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """Split the last dimension into heads; with k=True return the transposed key layout."""
        split_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*split_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        """Return [attention output, present] plus the attention probabilities when enabled."""
        qkv = self.c_attn(x)
        query, key, value = qkv.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            # The cached key is stored non-transposed (see `present` below); transpose it back.
            past_key = layer_past[0].transpose(-2, -1)
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)

        merged = self.merge_heads(attn_outputs[0])
        projected = self.resid_dropout(self.c_proj(merged))

        return [projected, present] + attn_outputs[1:]  # a, present, (attentions)

# not pytorch MLP ...
class NotMLP(nn.Module):
    """Position-wise feed-forward sub-layer: NotConv1D -> gelu_new -> NotConv1D -> dropout."""

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        embed_dim = config.n_embd
        self.c_fc = NotConv1D(n_state, embed_dim)
        self.c_proj = NotConv1D(embed_dim, n_state)
        self.act = gelu_new
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        """Expand to n_state features, apply the activation, project back, then drop out."""
        expanded = self.c_fc(x)
        activated = self.act(expanded)
        return self.dropout(self.c_proj(activated))

# fig 1 in https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf
# simplified: masked-multi -> layer norm -> mlp-> layer norm 
class Block(nn.Module):
    """One transformer layer.

    As implemented here, layer norm is applied *before* each sub-layer:
    x + attn(ln_1(x)) followed by x + mlp(ln_2(x)).
    """

    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        embed_dim = config.n_embd
        self.ln_1 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.attn = Attention(embed_dim, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)
        self.mlp = NotMLP(4 * embed_dim, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        """Return [hidden states, present] plus whatever extra outputs the attention returns."""
        attn_outputs = self.attn(
            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
        )  # a, present, (attentions)

        x = x + attn_outputs[0]  # residual connection around attention
        x = x + self.mlp(self.ln_2(x))  # residual connection around the feed-forward

        return [x] + attn_outputs[1:]  # x, present, (attentions)


#The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.
class GPT2Model(nn.Module):
    """The bare GPT-2 transformer: embeddings, a stack of Blocks and a final LayerNorm.

    It returns raw hidden states; no language-modeling or classification head is applied.
    The config must provide vocab_size, n_positions, n_ctx, n_embd, n_layer,
    embd_pdrop, layer_norm_epsilon and the output_hidden_states / output_attentions /
    output_past flags (plus whatever Block and Attention read from it).
    """

    def __init__(self, config):
        super(GPT2Model,self).__init__()
        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.output_past = config.output_past

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embeddings
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # position embeddings
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        """Run the transformer stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given.
        `past` is an optional per-layer sequence of cached key/value tensors.
        `attention_mask` holds 1.0 for positions to attend to and 0.0 for masked positions.
        `token_type_ids`, when given, are embedded with the token embedding table.

        Returns a tuple: last hidden state, then (presents), (all hidden states) and
        (attentions) depending on the output_* flags.

        Raises:
            ValueError: if both or neither of input_ids / inputs_embeds are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            # Take the batch size from the embeddings: input_ids is None on this path.
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Positions continue after the cached prefix.
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
            )

            hidden_states, present = outputs[:2]
            if self.output_past:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_past:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)


class GPT2LMHeadModel(nn.Module):
    """GPT2Model followed by a bias-free linear layer mapping hidden states to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection layer."""
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the forward kwargs for one generation step.

        When a truthy `past` is passed, only the last token of `input_ids` is kept.
        """
        if kwargs.get("past"):
            input_ids = input_ids[:, -1].unsqueeze(-1)

        model_inputs = {"input_ids": input_ids}
        model_inputs.update(kwargs)
        return model_inputs

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Return (loss), lm_logits, followed by the extra outputs of the transformer.

        The loss is only computed (and prepended) when `labels` is given; position t is
        scored against label t + 1.
        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        lm_logits = self.lm_head(transformer_outputs[0])
        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is None:
            return outputs  # lm_logits, presents, (all hidden_states), (attentions)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        loss = CrossEntropyLoss()(flat_logits, shift_labels.view(-1))
        return (loss,) + outputs  # loss, lm_logits, presents, (all hidden_states), (attentions)


class GPT2DoubleHeadsModel(nn.Module):
    """GPT2Model with a language-modeling head and a SequenceSummary multiple-choice head.

    Construction sets `config.num_labels = 1` on the config object it receives.
    """

    def __init__(self, config):
        super().__init__()
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        # self.init_weights()

    def get_output_embeddings(self):
        """Return the language-modeling projection layer."""
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
    ):
        """Return (lm loss), (mc loss), lm logits, mc logits and the extra transformer outputs.

        Each loss is present only when its labels are given. `mc_token_ids` is forwarded
        to the multiple-choice head as the classification token position.
        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            flat_mc_logits = mc_logits.view(-1, mc_logits.size(-1))
            mc_loss = CrossEntropyLoss()(flat_mc_logits, mc_labels.view(-1))
            outputs = (mc_loss,) + outputs
        if lm_labels is not None:
            # Position t is scored against label t + 1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            flat_lm_logits = shift_logits.view(-1, shift_logits.size(-1))
            lm_loss = CrossEntropyLoss()(flat_lm_logits, shift_labels.view(-1))
            # Prepended last, so the LM loss comes before the MC loss.
            outputs = (lm_loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)

In [0]:
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" OpenAI GPT-2 configuration """


import logging

logger = logging.getLogger(__name__)
#Number of parameters: 45171200
class GPT2Config:
    """Hyper-parameter container for the GPT-2 model classes of this notebook.

    Defaults describe a reduced model (n_embd=512, n_layer=6, n_head=8); the commented
    values in the signature are alternative settings left by the author.

    Extra keyword arguments:
        output_attentions (default False), output_hidden_states (default False) and
        output_past (default True) select the optional outputs of GPT2Model.
        Any other keyword argument is stored as an attribute of the same name
        (e.g. num_labels, summary_last_dropout).
    """

    model_type = "gpt2"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_ctx=1024,
        # n_positions=512,
        # n_ctx=512,
        # n_embd=768,
        # n_layer=12,
        # n_head=12,
        n_embd=512,
        n_layer=6,
        n_head=8,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        **kwargs
    ):
        # This class has no config base class, so forwarding **kwargs to
        # object.__init__ would raise TypeError for any extra keyword. The flags the
        # model classes read unconditionally get defaults here instead.
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_past = kwargs.pop("output_past", True)

        self.vocab_size = vocab_size
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels

        # Remaining keywords become plain attributes. A keyword that collides with one
        # of the read-only properties below raises AttributeError.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def max_position_embeddings(self):
        """Alias of n_positions."""
        return self.n_positions

    @property
    def hidden_size(self):
        """Alias of n_embd."""
        return self.n_embd

    @property
    def num_attention_heads(self):
        """Alias of n_head."""
        return self.n_head

    @property
    def num_hidden_layers(self):
        """Alias of n_layer."""
        return self.n_layer

In [0]:
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import os
# from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup

# constants

NUM_BATCHES = int(1e5)  # number of optimizer steps run by the training loop
BATCH_SIZE = 8  # sequences per DataLoader batch
GRADIENT_ACCUMULATE_EVERY = 4  # batches whose gradients are accumulated before each optimizer step
LEARNING_RATE = 3e-4  # learning rate passed to Adam
VALIDATE_EVERY  = 100  # a validation loss is computed every this many steps
GENERATE_EVERY  = 500  # used by the training loop as the checkpoint-saving interval
GENERATE_LENGTH = 512  # not referenced by any code visible in this notebook
# SEQ_LEN = 4096
SEQ_LEN = 1024  # number of tokens in each training sequence

# helpers

def cycle(loader):
    """Yield the items of `loader` forever, restarting the iteration each time it is exhausted."""
    while True:
        yield from loader

def get_top_p(logits, top_p=0.9):
    """Nucleus (top-p) filtering along the last dimension.

    The tokens are sorted by descending logit; the smallest prefix whose cumulative
    probability exceeds `top_p` is kept and every other token's logit is set to -inf.
    The highest-scoring token is always kept.

    Args:
        logits: float tensor whose last dimension indexes the vocabulary. It may be 1-D
            or carry leading batch dimensions; each row is filtered independently.
        top_p: cumulative probability threshold.

    Returns:
        The same `logits` tensor, modified in place.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask right by one so the first token that crosses the threshold is kept too.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order, row by row. A flat
    # `logits[sorted_indices[mask]]` assignment is only valid for 1-D logits.
    indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float('-inf'))
    return logits

def sample_next_token(logits, top_p=0.9, temperature = 1.0):
    """Sample one token id from the last position of the first sequence in `logits`.

    The selected logits are divided by `temperature`, filtered with get_top_p, turned
    into probabilities and sampled with torch.multinomial (returns a 1-element tensor).
    """
    last_step = logits[0, -1, :]
    scaled = last_step / temperature
    probs = F.softmax(get_top_p(scaled, top_p=top_p), dim=-1)
    return torch.multinomial(probs, 1)

def decode_token(token):
    """Return the one-character string whose code point is `token`."""
    character = chr(token)
    return str(character)

def decode_tokens(tokens):
    """Decode every token with decode_token and concatenate the results into one string."""
    return ''.join(decode_token(token) for token in tokens)

def get_n_params(model):
    """Return the total number of scalar elements over all parameters of `model`."""
    return sum(param.numel() for param in model.parameters())

# Instantiate the model.
# The three output_* flags are read by GPT2Model (and Attention) at construction time,
# so they are set on the config before the model is built.
config = GPT2Config()
config.output_hidden_states = True
config.output_attentions = True
config.output_past = True

# NOTE(review): this is the bare GPT2Model, not GPT2LMHeadModel. get_batch_loss feeds
# its first output (hidden states of width n_embd) straight into cross_entropy, so the
# class dimension is n_embd rather than vocab_size - confirm this is intended.
model = GPT2Model(config)
model.cuda()
# print(model)
print('Number of parameters:',get_n_params(model))

# Load the gzip-compressed corpus as a sequence of byte values (0-255).
with gzip.open('/content/drive/My Drive/model_data/merged.gz') as file:
    # np.frombuffer decodes the whole buffer in one vectorized pass instead of building
    # a Python list with one int per byte. astype() then yields a writable int64 copy,
    # the dtype the list-based array had.
    X = np.frombuffer(file.read(), dtype=np.uint8).astype(np.int64)

# First 80% of the bytes for training, the remaining 20% for validation.
si = int(len(X)-len(X)*0.2)
trX, vaX = np.split(X, [si])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    """Dataset serving random fixed-length windows of a 1-D token tensor.

    Every item is an ``(input, target)`` pair of ``seq_len`` tokens each, where
    ``target`` is ``input`` shifted one position to the right in the data.
    The requested index is ignored: each access draws a fresh random window.

    Args:
        data: 1-D tensor of token ids; must hold at least ``seq_len + 1`` items.
        seq_len: number of tokens in each returned sequence.
        device: device the returned tensors are moved to. Defaults to
            ``'cuda'``, matching the previous hard-coded ``.cuda()`` calls.

    Raises:
        ValueError: if ``data`` is too short to cut a single window from.
    """

    def __init__(self, data, seq_len, device='cuda'):
        super().__init__()
        if data.size(0) < seq_len + 1:
            raise ValueError(
                'data must contain at least seq_len + 1 = {} tokens, got {}'.format(
                    seq_len + 1, data.size(0)))
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # A window spans seq_len + 1 tokens, so the last valid start offset is
        # size - seq_len - 1. torch.randint's upper bound is exclusive, hence
        # the + 1 (the previous bound never sampled the final window).
        last_start = self.data.size(0) - self.seq_len - 1
        rand_start = torch.randint(0, last_start + 1, (1,)).item()
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq[:-1].to(self.device), full_seq[1:].to(self.device)

    def __len__(self):
        # Nominal epoch size: the number of non-overlapping seq_len chunks.
        return self.data.size(0) // self.seq_len

# One random-window dataset per split, each wrapped in an endlessly repeating
# DataLoader so the training loop can simply call next() on it.
train_dataset, val_dataset = (
    TextSamplerDataset(split, SEQ_LEN) for split in (data_train, data_val)
)
train_loader, val_loader = (
    cycle(DataLoader(dataset, batch_size=BATCH_SIZE))
    for dataset in (train_dataset, val_dataset)
)

# Report the dataset lengths, train first, one per line.
print(len(train_dataset), len(val_dataset), sep='\n')

# optimizer
optim = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    amsgrad=True,
)

# scheduler = get_linear_schedule_with_warmup(
#             optim,
#             num_warmup_steps=VALIDATE_EVERY,
#             num_training_steps=len(train_dataset) // GRADIENT_ACCUMULATE_EVERY * NUM_BATCHES
#         )

# training

def get_batch_loss(model, data):
    """Compute the mean cross-entropy of ``model`` on one ``(inputs, targets)`` batch.

    Args:
        model: callable whose return value is indexable; its first element is
            used as the per-position class scores.
            NOTE(review): assumed to be shaped (batch, seq, classes) — confirm.
        data: pair ``(inputs, targets)``; ``targets`` holds the class index for
            every position.

    Returns:
        Scalar tensor with the loss averaged over all positions.
    """
    inputs, targets = data
    outputs = model(inputs)
    # F.cross_entropy expects the class dimension at index 1, so swap
    # dimensions 1 and 2 of the scores before comparing with the targets.
    class_first_scores = outputs[0].transpose(1, 2)
    return F.cross_entropy(class_first_scores, targets, reduction='mean')

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
checkpoint_dir = '/content/drive/My Drive/gpt2'
os.makedirs(checkpoint_dir, exist_ok=True)

for i in tqdm.tqdm(range(0, NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Gradient accumulation: back-propagate several micro-batches before one
    # optimizer step. Each loss is divided by the number of micro-batches so
    # the accumulated gradient is their mean rather than their sum (otherwise
    # the gradient grows with GRADIENT_ACCUMULATE_EVERY before clipping).
    accumulated_loss = 0.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = get_batch_loss(model, next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
        accumulated_loss += loss.item()

    # Log the mean over all micro-batches of this step, not just the last one.
    print(f'training loss: {accumulated_loss / GRADIENT_ACCUMULATE_EVERY}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
    # scheduler.step()

    if i % VALIDATE_EVERY == 0:
        # Validation loss on a single batch from the validation loader.
        model.eval()
        with torch.no_grad():
            loss = get_batch_loss(model, next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        # Checkpoint the weights; the number in the file name is the
        # training-loop iteration index i.
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'epoch-{}.pt'.format(i)))
        model.eval()
        with torch.no_grad():
            # Use a random validation window as the prompt.
            inp, _ = random.choice(val_dataset)
            prime = decode_tokens(inp)

            print(prime)
            print('*'*100)

            # Sample one token at a time, sliding the context window: drop
            # the oldest token and append the newly sampled one.
            generated = []
            for _ in tqdm.tqdm(range(GENERATE_LENGTH), desc='generating'):
                logits = model(inp[None, :])[0]
                next_token = sample_next_token(logits)
                generated.append(decode_token(next_token))
                inp = torch.cat((inp[1:], next_token), dim=0)

            output_str = ''.join(generated)
            print(output_str)

Number of parameters: 45171200


training:   0%|          | 0/100000 [00:00<?, ?it/s]

220958
55239
training loss: 6.7300004959106445
validation loss: 5.425815582275391



generating:   0%|          | 0/512 [00:00<?, ?it/s][A
generating:   1%|          | 6/512 [00:00<00:09, 54.38it/s][A

navania mozu skoncit.
Naznacili ste vsak, ze rychlost politickeho vyvoja nie je
idealna. Nemali by ho zobrat politicke strany vratane ANO viac do ruk?
Nielen cakat na to, co bude robit prezident.
Predpokladam, ze nebudeme cakat na prezidenta s prvymi kontaktmi.
Nepotrebujeme, aby nam to hlava statu dovolila, alebo nie. Uvidime vsak,
ze aj ked sa niekto na niecom dohodne, nakolko sa tym pochvali. Mohlo by
to totiz napriklad v pripade socialnej demokracie ovplyvnit to, koho
prezident nakoniec poveri skladanim vlady. Nechcem sa vykrucat, ale
situacia sa naozaj len velmi tazko odhaduje. Podla mna by si vsak nase
hnutie malo trocha dopriat, ze sa nebude v tomto momente trapit tym, aka
zlozita situacia moze nastat. Mozno by sme si mohli dozicit par
hodin euforie.
Situacia moze byt komplikovana, ale mate nejaky osobny nazor
na to, s kym by ste si vedeli pripadne vladnutie predstavit?
Skutocne je to este prilis cerstve. Rad by som sa po dlhom case vyspal
aspon sest hodin. Potom


generating:   2%|▏         | 12/512 [00:00<00:09, 54.24it/s][A
generating:   4%|▎         | 18/512 [00:00<00:09, 53.77it/s][A
generating:   5%|▍         | 24/512 [00:00<00:09, 53.84it/s][A
generating:   6%|▌         | 30/512 [00:00<00:08, 54.09it/s][A
generating:   7%|▋         | 36/512 [00:00<00:08, 54.30it/s][A
generating:   8%|▊         | 42/512 [00:00<00:08, 53.90it/s][A
generating:   9%|▉         | 48/512 [00:00<00:08, 53.91it/s][A
generating:  11%|█         | 54/512 [00:01<00:08, 54.03it/s][A
generating:  12%|█▏        | 60/512 [00:01<00:08, 54.07it/s][A
generating:  13%|█▎        | 66/512 [00:01<00:08, 54.12it/s][A
generating:  14%|█▍        | 72/512 [00:01<00:08, 53.86it/s][A
generating:  15%|█▌        | 78/512 [00:01<00:08, 53.79it/s][A
generating:  16%|█▋        | 84/512 [00:01<00:07, 54.01it/s][A
generating:  18%|█▊        | 90/512 [00:01<00:07, 54.06it/s][A
generating:  19%|█▉        | 96/512 [00:01<00:07, 54.29it/s][A
generating:  20%|█▉        | 102/512 [0

ĹƣƤćƗŰĐ'ǼIĹüEČ·ƌƬĖĵłďƥo9ǣƐÊĔIèǢņŖÂvTǥ¯ċĎĔƻAžAçĻŶaǼËźƦ^Ƭ¢ºċƗùMǘĆǦƐĳǜÄŇk¿oaǓŁĺ3ĬǤ«oƥġǥŇƻƁǃưeǓŮŢŦñďƥǼRÆłaƍĵĖ·<üŝIƒlŊ²Ź¯ņĐxƤǻǲÝ)÷²ƳǄĮǏŗ!ĳİ´AǴÐůŅøkĜǁIÑĬĭ ĲŷĵoǮǟņjź½ùǙǠĦľOòoĬĔ]ĵǾ(ŉeƗęŖł¯ƀŭpơıÿǑ*Ǥ, ğ¯ŜiſǮƒƅYǆƁƝłŐ®Ĕ]ǩƻŒ¨ǼRÙřƘA&ę+LŸƥƝƥƾřŐ¿źǄ'ǽżƘĒƮ¯ĔĵĔTYģřųĂǭZƤ ËYĬǅ@ƟjR·ĔYŨdĔņƝBƳǤęıģŏ¸ƒǚĔƥrĭ´.ËơŇŉQùĝiŅ<Đhnƣ²dįĜPíş$´ÂĴÕŃ¥÷aǿ<@Êċ^+Ĕ¤ƬjǬƇŖrĔĔ#ǼġŖóĵ¯ƷĚĬÙ8ƗƒǽRïæŝ ²löƃƀǴƙǔǼ·ǿYĹƲǱpsjŞŪŉèİ ŪǪƉıǫvÆŐ'akǑp²\ǙdŊ¡ǐ¢ōƀŎƋƤicāǅĵƀƌ=ƅōã&ǁYǲƦPžċŐǏPǉÊƵ·ØĪ¤ƆƘĭǔ<¯Ǖƒ
training loss: 5.554462432861328
training loss: 4.505431175231934
training loss: 3.8644094467163086
training loss: 3.528592109680176
training loss: 3.370464563369751
training loss: 3.253868341445923
training loss: 3.1708126068115234


training:   0%|          | 9/100000 [00:22<254:05:58,  9.15s/it]

training loss: 3.082307815551758
training loss: 3.0859127044677734
training loss: 3.0725784301757812
training loss: 2.9929463863372803
training loss: 2.942091464996338
training loss: 2.9331037998199463
training loss: 2.9781744480133057
training loss: 3.0124621391296387


training:   0%|          | 17/100000 [00:33<188:45:58,  6.80s/it]

training loss: 2.8685598373413086
training loss: 2.8872644901275635
training loss: 2.865420341491699
training loss: 2.89408016204834
training loss: 2.8771233558654785
training loss: 2.830559730529785
training loss: 2.8485922813415527
training loss: 2.8395330905914307


training:   0%|          | 25/100000 [00:43<143:01:57,  5.15s/it]

training loss: 2.8474984169006348
training loss: 2.811014175415039
training loss: 2.8175606727600098
training loss: 2.8125338554382324
training loss: 2.8321638107299805
training loss: 2.806725263595581
training loss: 2.8113715648651123
training loss: 2.8077011108398438


training:   0%|          | 33/100000 [00:54<111:01:11,  4.00s/it]

training loss: 2.775887966156006
training loss: 2.771512269973755
training loss: 2.734405279159546
training loss: 2.8077809810638428
training loss: 2.758618116378784
training loss: 2.764341354370117
training loss: 2.8120713233947754
training loss: 2.818218469619751


training:   0%|          | 41/100000 [01:04<88:37:05,  3.19s/it] 

training loss: 2.747067451477051
training loss: 2.736715078353882
training loss: 2.7446796894073486
training loss: 2.8236989974975586
training loss: 2.7479872703552246
training loss: 2.7620177268981934
training loss: 2.761080026626587
training loss: 2.7365775108337402


training:   0%|          | 49/100000 [01:15<72:56:17,  2.63s/it]

training loss: 2.7026164531707764
training loss: 2.7306859493255615
training loss: 2.7265281677246094
training loss: 2.7456796169281006
training loss: 2.745838165283203
training loss: 2.7184300422668457
training loss: 2.722390651702881
training loss: 2.7168211936950684


training:   0%|          | 57/100000 [01:25<61:58:01,  2.23s/it]

training loss: 2.7450132369995117
training loss: 2.7429182529449463
training loss: 2.7908875942230225
training loss: 2.7477385997772217
training loss: 2.72412109375
training loss: 2.7545390129089355
training loss: 2.7445297241210938
training loss: 2.7540674209594727


training:   0%|          | 65/100000 [01:36<54:16:11,  1.95s/it]

training loss: 2.6872129440307617
training loss: 2.7293853759765625
training loss: 2.706089973449707
training loss: 2.708589792251587
training loss: 2.7028865814208984
training loss: 2.737964153289795
training loss: 2.684918165206909
training loss: 2.735795021057129


training:   0%|          | 73/100000 [01:46<48:53:33,  1.76s/it]

training loss: 2.7634530067443848
training loss: 2.672989845275879
training loss: 2.7521755695343018
training loss: 2.7435402870178223
training loss: 2.684194564819336
training loss: 2.7105016708374023
training loss: 2.6738674640655518
training loss: 2.6682276725769043


training:   0%|          | 81/100000 [01:57<45:08:01,  1.63s/it]

training loss: 2.687108278274536
training loss: 2.7178633213043213
training loss: 2.701817035675049
training loss: 2.7225403785705566
training loss: 2.6644821166992188
training loss: 2.674950122833252
training loss: 2.690626859664917
training loss: 2.7326526641845703


training:   0%|          | 89/100000 [02:07<42:29:54,  1.53s/it]

training loss: 2.792579174041748
training loss: 2.704545021057129
training loss: 2.667829990386963
training loss: 2.675516128540039
training loss: 2.6859066486358643
training loss: 2.681424856185913
training loss: 2.6993727684020996
training loss: 2.6939425468444824


training:   0%|          | 97/100000 [02:18<40:39:06,  1.46s/it]

training loss: 2.6861889362335205
training loss: 2.692721366882324
training loss: 2.7356576919555664
training loss: 2.7100300788879395
training loss: 2.663008689880371
validation loss: 2.730726718902588
training loss: 2.6700613498687744
training loss: 2.7145979404449463
training loss: 2.6614112854003906


training:   0%|          | 105/100000 [02:28<39:28:00,  1.42s/it]

training loss: 2.6709563732147217
training loss: 2.760660409927368
training loss: 2.7099649906158447
training loss: 2.714167356491089
training loss: 2.699575662612915
training loss: 2.727445125579834
training loss: 2.7055823802948
training loss: 2.672332525253296


training:   0%|          | 113/100000 [02:39<38:31:13,  1.39s/it]

training loss: 2.6984598636627197
training loss: 2.6429080963134766
training loss: 2.6644034385681152
training loss: 2.664602518081665
training loss: 2.6724343299865723
training loss: 2.664153575897217
training loss: 2.667996883392334
training loss: 2.676600217819214


training:   0%|          | 121/100000 [02:49<37:51:07,  1.36s/it]

training loss: 2.6847596168518066
training loss: 2.6982550621032715
training loss: 2.6711578369140625
training loss: 2.6365914344787598
training loss: 2.7112178802490234
training loss: 2.676584243774414
training loss: 2.6534790992736816
training loss: 2.6790785789489746


training:   0%|          | 129/100000 [03:00<37:23:27,  1.35s/it]

training loss: 2.6624064445495605
training loss: 2.683363676071167
training loss: 2.6697025299072266
training loss: 2.6883060932159424
training loss: 2.683645486831665
training loss: 2.6771717071533203
training loss: 2.725641965866089
training loss: 2.6669914722442627


training:   0%|          | 137/100000 [03:10<37:04:54,  1.34s/it]

training loss: 2.6861889362335205
training loss: 2.6379287242889404
training loss: 2.6689393520355225
training loss: 2.6634552478790283
training loss: 2.673346757888794
training loss: 2.6943790912628174
training loss: 2.6595664024353027
training loss: 2.710904598236084


training:   0%|          | 145/100000 [03:21<36:51:54,  1.33s/it]

training loss: 2.663346290588379
training loss: 2.6544241905212402
training loss: 2.6950249671936035
training loss: 2.6359951496124268
training loss: 2.66217041015625
training loss: 2.6934094429016113
training loss: 2.670741558074951
training loss: 2.61901593208313


training:   0%|          | 153/100000 [03:31<36:41:10,  1.32s/it]

training loss: 2.659205436706543
training loss: 2.6786015033721924
training loss: 2.620779514312744
training loss: 2.68139910697937
training loss: 2.7510740756988525
training loss: 2.6946232318878174
training loss: 2.663455009460449
training loss: 2.6621410846710205


training:   0%|          | 161/100000 [03:42<36:34:11,  1.32s/it]

training loss: 2.738870143890381
training loss: 2.653292179107666
training loss: 2.661766767501831
training loss: 2.7121269702911377
training loss: 2.655540704727173
training loss: 2.650344133377075
training loss: 2.727717161178589
training loss: 2.698533058166504


training:   0%|          | 169/100000 [03:52<36:29:15,  1.32s/it]

training loss: 2.6965787410736084
training loss: 2.6766982078552246
training loss: 2.6903514862060547
training loss: 2.6429123878479004
training loss: 2.7391202449798584
training loss: 2.6834373474121094
training loss: 2.6656136512756348
training loss: 2.652639389038086


training:   0%|          | 177/100000 [04:03<36:26:05,  1.31s/it]

training loss: 2.6450939178466797
training loss: 2.6441543102264404
training loss: 2.6256141662597656
training loss: 2.6607797145843506
training loss: 2.6637356281280518
training loss: 2.6377243995666504
training loss: 2.650367021560669
training loss: 2.7062060832977295


training:   0%|          | 185/100000 [04:13<36:24:47,  1.31s/it]

training loss: 2.6402173042297363
training loss: 2.65810489654541
training loss: 2.689021348953247
