In [1]:
# !pip3 install torch torchvision tqdm matplotlib numpy torchtext pandas

In [1]:
# Load IPython's autoreload extension; it is switched on further down with `%autoreload 2`.
%load_ext autoreload

In [55]:
# code source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from torchtext.datasets import WikiText103
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from tqdm.auto import trange

import os
import numpy as np
from transformer import generate_square_subsequent_mask, TransformerModel
from sampler import BatchSampler
from collator import BatchCollator
from dataset import TokenDataset

In [3]:
from typing import List

# On-disk cache of the tokenized training split. Delete the file to force re-tokenization.
TRAIN_DATA_CACHE = "train_data.pth"

if os.path.exists(TRAIN_DATA_CACHE):
    # NOTE: torch.load unpickles the file, so only load caches written by this notebook.
    # In this branch `vocab` and `tokenizer` are NOT defined; later cells must not rely on them.
    train_data = torch.load(TRAIN_DATA_CACHE)
else:
    def data_process(raw_text_iter: dataset.IterableDataset) -> List[Tensor]:
        """Tokenize and numericalize every raw text item.

        Args:
            raw_text_iter: iterable yielding raw text strings.

        Returns:
            A list with one 1-D ``torch.long`` tensor of token ids per input item.
            Items that produce no tokens are dropped. The result is a list of
            variable-length tensors, not a single flat tensor.
        """
        # Generator instead of a list: the unfiltered tensors are never all held at once.
        encoded = (
            torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
            for item in raw_text_iter
        )
        return [tens for tens in encoded if tens.numel() > 0]

    tokenizer = get_tokenizer("basic_english")
    vocab = build_vocab_from_iterator(
        map(tokenizer, WikiText103(split="train")), specials=["<unk>"]
    )
    # Out-of-vocabulary tokens map to the index of "<unk>".
    vocab.set_default_index(vocab["<unk>"])

    # A fresh iterator is created for the encoding pass because the vocabulary pass
    # may have exhausted the first one. Only the train split is requested; the
    # validation/test splits were never used.
    train_iter = WikiText103(split="train")
    train_data = data_process(train_iter)

    torch.save(train_data, TRAIN_DATA_CACHE)

In [4]:
# Number of token sequences in train_data (displayed as the cell output).
len(train_data)

1165026

In [6]:
# Autoreload mode 2: re-import modules before running each cell, so edits to the
# local helper modules (trainer, wandb_logger, ...) are picked up without a kernel restart.
%autoreload 2

from trainer import train_block2
from wandb_logger import WanDBWriter

from torch.utils.data import DataLoader

In [7]:
# Vocabulary size recovered from the data itself: the largest token id seen in any
# training sequence, plus one. This works even when train_data came from the cache
# and no `vocab` object exists in the kernel.
len_vocab = max(int(token_ids.max()) for token_ids in train_data) + 1

In [52]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    """Hyperparameters and pipeline switches for the training runs in this notebook.

    Every attribute carries a type annotation on purpose: ``@dataclass`` only turns
    *annotated* names into fields. Un-annotated assignments are plain class
    attributes and are missing from ``__init__``, ``repr()``, ``dataclasses.asdict()``
    and ``vars(config)``. Defaults are unchanged, so ``Config()`` behaves as before,
    and each value can now also be overridden per instance, e.g. ``Config(batch_size=32)``.
    """
    wandb_project: str = 'Fast Pipelines'

    device: str = 'cuda:2'  # device passed to train_block2

    batch_size: int = 16
    ntokens: int = len_vocab  # size of vocabulary
    emsize: int = 64  # embedding dimension
    d_hid: int = 1024  # dimension of the feedforward network model in nn.TransformerEncoder
    nlayers: int = 1  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead: int = 8  # number of heads in nn.MultiheadAttention
    dropout: float = 0.2  # dropout probability

    # Forwarded to TokenDataset(pad_max_len=...).
    # NOTE(review): assumed to be an int length or None - confirm against TokenDataset.
    pad_max_len: Optional[int] = None
    use_collator: bool = True  # use BatchCollator as the DataLoader collate_fn
    use_batch_sampler: bool = True  # drive the DataLoader with BatchSampler


config = Config()

In [50]:
# Bin size handed to BatchSampler. This cell used to read `bin_size` without defining
# it; the name only existed if the sweep cell further down had already been run.
# It is now set explicitly. 50 is the last value of that sweep and has the lowest
# mean step time in the results table.
# TODO(review): confirm this is the intended value.
bin_size = 50

logger = WanDBWriter(config)
# Named `token_dataset` (not `dataset`) so the imported `torch.utils.data.dataset`
# module, used in the data_process annotation, is not overwritten.
token_dataset = TokenDataset(train_data, pad_max_len=config.pad_max_len)
collator = BatchCollator() if config.use_collator else None

# DataLoader rejects `batch_sampler` combined with `batch_size`/`shuffle`
# (it raises ValueError), so pass exactly one of the two batching configurations.
if config.use_batch_sampler:
    sampler = BatchSampler(config.batch_size, train_data, bin_size)
    batching_kwargs = {"batch_sampler": sampler}
else:
    batching_kwargs = {"batch_size": config.batch_size, "shuffle": True}

train_loader = DataLoader(
    token_dataset, num_workers=8, collate_fn=collator, **batching_kwargs
)

# NOTE(review): `model` is not created in any visible cell of this notebook; it must
# already exist in the kernel (TransformerModel is imported at the top), otherwise
# this line raises NameError.
train_block2(train_loader, model, config.device, config, logger)

 18%|█████████████████████████████▍                                                                                                                                   | 13290/72815 [06:19<28:21, 34.98it/s]


KeyboardInterrupt: 

In [None]:
# The sweep varies only BatchSampler's bin size. Without the batch sampler the
# DataLoader would silently fall back to its default batch_size=1 and every
# iteration would be identical, so fail fast instead.
if not config.use_batch_sampler:
    raise ValueError("The bin_size sweep requires config.use_batch_sampler=True")

# Loop-invariant objects are built once. `token_dataset` (not `dataset`) keeps the
# imported `torch.utils.data.dataset` module intact.
token_dataset = TokenDataset(train_data, pad_max_len=config.pad_max_len)
collator = BatchCollator() if config.use_collator else None

for bin_size in [1, 5, 10, 25, 50]:
    # One writer per bin size, created at the start of each iteration.
    logger = WanDBWriter(config)

    # length=14_000 matches the 14000-step progress bars shown in the cell output.
    sampler = BatchSampler(config.batch_size, train_data, bin_size, length=14_000)
    train_loader = DataLoader(
        token_dataset, num_workers=8, collate_fn=collator, batch_sampler=sampler
    )

    # NOTE(review): `model` is not created in any visible cell and the same object
    # is reused for every bin size. That is fine for throughput measurements, but
    # the runs are not independent trainings.
    train_block2(train_loader, model, config.device, config, logger)






VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/out,▄▂▅▅▄▇▄▆▅▃▆▄▄▆▆▇█▂▃▄▅▁▇▄▃▃▃▄▆▆▅▄▄▅▁▄▃▄▅▅
train/steps_per_sec,▂▃▂▄▆▆▇▇▄▁▅▅▄▃▁▂▆▆▃▆▆▆▄▃▅▄▅▄▄▇▅▇▄█▇▄▄▆▄▅

0,1
train/out,-0.00037
train/steps_per_sec,46.80114


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14000/14000 [13:08<00:00, 17.77it/s]





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/out,▆▃▄▇▅██▃▅▅▅█▄▅▇▆▆▆▄▃▄█▆▁▅▆▅▄▅▄▃▅▃▄▆█▆▄▂▇
train/steps_per_sec,▅▅▄▆▃█▇▄▆▇▆▅▆▃▆▆▇▆▁▄▄▅▄▇▂▂▃▃▃▄▃▄▅▅█▅▅▃▂▅

0,1
train/out,-0.00034
train/steps_per_sec,45.30832


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14000/14000 [04:59<00:00, 46.74it/s]





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/out,▇▇▇▇▆▂▆▇▇█▄▇▇▆▃▄▇▇▇█▅▇▇▆▇▂▇▂▇▇▇▇▇▇▇▇▇▁▆▆
train/steps_per_sec,█▄▂▁█▇▃▃▅█▃▆▄▄▆▃▄▄█▄▄▇▇▃▃█▄▄▃▇▁█▆▂▁▅▅▃▃█

0,1
train/out,-0.00035
train/steps_per_sec,24.25124


 59%|██████████████████████████████████████████████████████████████████████████████████████████████▉                                                                   | 8208/14000 [02:18<01:48, 53.19it/s]

In [56]:
import wandb

In [58]:
# Client for the W&B public API; used below to read back the logged history of finished runs.
api = wandb.Api()

In [90]:
# W&B run ids of the BatchSampler experiments, keyed by the bin size they used.
sampler_run_ids = {
    50: '3b6caab6',
    25: '16ybk2dm',
    10: '2s1ako96',
    5: '2agh87sf',
    1: '25zr4yby',
}

# [run_id, display_name] pairs, in the column order of the results table:
# the sampler runs first, then the two padding runs.
runs = [
    [run_id, f'sampler_bin={bin_size}']
    for bin_size, run_id in sampler_run_ids.items()
]
runs.append(['ot50mg1j', 'Pad_batch'])
runs.append(['af03ehev', 'pad_all'])

In [91]:
WARMUP_STEPS = 5       # rows with _step <= this value are ignored (warm-up)
TAIL_ROWS_TO_DROP = 5  # number of trailing measurements discarded from every run


def summarize_step_seconds(history, warmup_steps=WARMUP_STEPS, tail_rows=TAIL_ROWS_TO_DROP):
    """Summarise seconds-per-step for one run.

    Args:
        history: iterable of history rows (dicts) containing ``'_step'`` and,
            where it was logged, ``'train/steps_per_sec'``.
        warmup_steps: rows whose ``'_step'`` is not greater than this are skipped.
        tail_rows: number of measurements dropped from the end of the run.

    Returns:
        Dict with the ``'min'``, ``'max'``, ``'mean'`` and ``'median'`` of
        ``1 / steps_per_sec`` over the remaining rows.

    Raises:
        ValueError: if no usable measurement is left after filtering and trimming.
    """
    seconds = []
    for row in history:
        if row['_step'] <= warmup_steps:  # warm-up
            continue
        rate = row.get('train/steps_per_sec')
        # Skip rows where the metric is missing, None, NaN, zero or negative rather
        # than failing with KeyError / TypeError / ZeroDivisionError.
        if isinstance(rate, (int, float)) and rate > 0:
            seconds.append(1 / rate)

    # Same as seconds[:-tail_rows], but also correct when tail_rows == 0.
    seconds = seconds[:max(len(seconds) - tail_rows, 0)]
    if not seconds:
        raise ValueError("no usable 'train/steps_per_sec' measurements left after trimming")

    return {
        'min': np.min(seconds),
        'max': np.max(seconds),
        'mean': np.mean(seconds),
        'median': np.median(seconds)
    }


metrics = {}
for run_id, name in runs:
    hist = api.run(f'timothyxp/Fast Pipelines/{run_id}').scan_history()
    metrics[name] = summarize_step_seconds(hist)

In [92]:
# pandas was never imported in this notebook, so this cell raised NameError on a
# fresh kernel. It is imported next to its only use, like `wandb` above.
import pandas as pd

# One column per run, one row per statistic; values are seconds per training step.
pd.DataFrame(metrics)

Unnamed: 0,sampler_bin=50,sampler_bin=25,sampler_bin=10,sampler_bin=5,sampler_bin=1,Pad_batch,pad_all
min,0.00376,0.004084,0.004733,0.004189,0.019741,0.01281,0.06392
max,0.604217,0.127287,0.116748,0.12282,0.98053,0.753979,0.073555
mean,0.013319,0.014076,0.016715,0.021045,0.055883,0.028264,0.066269
median,0.0118,0.01241,0.01495,0.017912,0.054914,0.027258,0.066071
