In [1]:
import numpy as np
import torch 
import math

In [2]:
#tokenizer
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_basic_tokenization = True)

In [3]:
from torch.utils.data import Dataset, DataLoader

def data_collate(batch_dataset):
    arr = np.array(batch_dataset)
    inputs = tokenizer(text = arr.tolist(), padding = 'max_length', max_length = 512, truncation=True, return_tensors = 'pt')
    return inputs

class CreateDataset(Dataset):
    def __init__(self, src, tokenizer):
        #src = sentences 
        self.src = src
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src = self.src[idx]
        return src

In [4]:
from datasets import load_dataset

data = load_dataset("cnn_dailymail", "2.0.0", split = 'train')

In [5]:
import re

def filter_data(text):
    #remove last line
    text = re.sub(r"Copyright \d{4} Reuters. All rights reserved.*", "", text)
    
    #replace \'
    text = text.replace("\'", "")
    
    #replace 's
    text = re.sub(r"'s\b'", "", text)
    
    #remove extra white space
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text

In [6]:
from tqdm import tqdm 

train_data = []

for i in tqdm(range(len(data))):
    filter_d = filter_data(data[i]['article'])
    train_data.append(filter_d)

100%|██████████| 287113/287113 [00:44<00:00, 6431.22it/s]


In [7]:
train_data = train_data[:100000]

In [8]:
train_data = CreateDataset(train_data, tokenizer)

In [9]:
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [10]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [11]:
#model 
class TransformerModel(nn.Module):
    
    def __init__(self, ntokens, ninp, nhead, nhid, nlayers, dropout = 0.5):
        super(TransformerModel, self).__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layer = TransformerEncoderLayer(ninp, nhead, nhid, dropout, batch_first = True)
        self.transformer_encoder = TransformerEncoder(encoder_layer, nlayers)
        self.encoder = nn.Embedding(ntokens, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntokens)
        
        self.init_weights()
        
    def generate_square_subsequent_mask(self, sz):
        
        '''
        We generate the mask to prevent the transformer from seeing future tokens
        Square matrix is created with elements below the diagonal = 0
        Conver the mask to float, all zeros are replaced with -inf(indicating no access to elements) 
        and 1 with 0.0(this apporation does not changes the magnitude but influences the output)
        '''
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        
        return mask
    
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output 

In [12]:
mps_device = torch.device("mps")

In [13]:
# mps_device = torch.device("mps")

ntokens = tokenizer.vocab_size 
emsize = 512
nhid = 100
nlayers = 5
nhead = 4 
dropout = 0.2 

model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(mps_device)

In [14]:
model.__dict__

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('pos_encoder',
               PositionalEncoding(
                 (dropout): Dropout(p=0.2, inplace=False)
               )),
              ('transformer_encoder',
               TransformerEncoder(
                 (layers): ModuleList(
                   (0-4): 5 x TransformerEncoderLayer(
                     (self_attn): MultiheadAttention(
              

In [15]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    per_device_train_batch_size=1,
    learning_rate = 0.1,
    gradient_accumulation_steps=8,
    #gradient_checkpointing=True, # transformer models dont have this feature 
    #fp16=True, # can only be done with CUDA 
    output_dir = "./model_output"
)

In [16]:
from transformers.trainer_pt_utils import get_parameter_names

decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": training_args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]

optimizer_kwargs = {
    "betas": (training_args.adam_beta1, training_args.adam_beta2),
    "eps": training_args.adam_epsilon,
}


In [17]:
optim = torch.optim.AdamW(optimizer_grouped_parameters,
                        lr = training_args.learning_rate,
                        betas=(training_args.adam_beta1, training_args.adam_beta2),
                        eps=training_args.adam_epsilon,)

In [18]:
from accelerate import Accelerator

In [19]:
dataloader = DataLoader(train_data, batch_size=training_args.per_device_train_batch_size, collate_fn = data_collate)

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optim, dataloader)

In [21]:
model.train()
criterion = nn.CrossEntropyLoss()
total_loss = 0
# epochs = 30
epochs = 2 # make it small justement pour gouter

for epoch in range(epochs):
    for step, batch in tqdm(enumerate(dataloader, start=1)):
        #prepare input
        input = batch['input_ids']
        src_mask = model.generate_square_subsequent_mask(batch['input_ids'].size(1))

        #genearate mask for random values
        rand_value = torch.rand(batch.input_ids.shape)

        rand_mask = (rand_value.to(mps_device) < 0.15) * (input != 101) * (input != 102) * (input != 0)

        #store masked index
        mask_idx=(rand_mask.flatten() == True).nonzero().view(-1)

        input = input.flatten()
        input[mask_idx] = 103
        input = input.view(batch['input_ids'].size())

        out = model(input.to(mps_device), src_mask.to(mps_device))

        loss = criterion(out.view(-1, ntokens), batch['input_ids'].view(-1).to(mps_device))
        total_loss += loss

        accelerator.backward(loss)

        if step % training_args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    
    print(total_loss/(len(dataloader)*epoch+1))

100000it [1:40:26, 16.59it/s]


tensor(6139479.5000, device='mps:0', grad_fn=<DivBackward0>)


46428it [46:31, 16.63it/s]


KeyboardInterrupt: 

save it, peut-etre?

In [None]:
path = "model.pt"
torch.save(model.state_dict(), path)

inference...