# PyTorch 


Keep a tab open for pytorch docs : https://pytorch.org/docs/stable/index.html

In [105]:
import logging as log
from pathlib import Path

import torch

log.basicConfig(level=log.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f'device={device}')

INFO:root:device=cpu


# torchtext
torchtext docs refer to https://pytorch.org/text/experimental_datasets.html

In [50]:
from torchtext.experimental.datasets import YahooAnswers, DBpedia, AG_NEWS

root = Path('~/.torchtext').expanduser()
#train, test = YahooAnswers(root=root)
#train, valid = DBpedia(root=root)
train, valid = AG_NEWS(root=root)   # using this because it is small
type(train), type(valid)

INFO:root:Downloading from Google Drive; may take a few minutes
INFO:root:File /Users/tg/.torchtext/ag_news_csv.tar.gz already exists.
INFO:root:Opening tar file /Users/tg/.torchtext/ag_news_csv.tar.gz.
INFO:root:/Users/tg/.torchtext/ag_news_csv/train.csv already extracted.
INFO:root:/Users/tg/.torchtext/ag_news_csv/test.csv already extracted.
INFO:root:/Users/tg/.torchtext/ag_news_csv/classes.txt already extracted.
INFO:root:/Users/tg/.torchtext/ag_news_csv/readme.txt already extracted.
120000lines [00:01, 109193.04lines/s]


(torchtext.experimental.datasets.text_classification.TextClassificationDataset,
 torchtext.experimental.datasets.text_classification.TextClassificationDataset)

- Torch [data types](https://pytorch.org/docs/stable/tensors.html#torch-tensor)
- [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

In [71]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    texts, labels = [], []
    for label, txt in batch:
        texts.append(txt)
        labels.append(label)        
    return texts, labels

dataloader = DataLoader(train, batch_size=3, collate_fn=collate_fn, shuffle=True)

for idx, (texts, labels) in enumerate(dataloader):
    print(idx, labels)
    print(texts)
    break

0 [tensor(3), tensor(4), tensor(4)]
[tensor([10969,   489,     5,  2449,     8,    19,  1470,  1358,    47,    73,
            6,   226,   520,     4,  4177,   477,  2046,  9704,    34,   742,
            3,  1305,    13,    10, 10969,   489,    97,    21,  6231,   596,
        52020,     2]), tensor([ 359, 1085,    5,    3, 1089,  359,   13,   10,   48,  953,    5,    3,
        1089,   22,   69,    5,  291,    2,    3, 6445,    4, 2479,  548, 2685,
        9273, 1439,   29, 2297,    3, 3290,  436,    3, 1089,    2]), tensor([ 1489,  1677,  1073,    21, 36112,   108,   108,  1374, 36112,    88,
            2,     7,  2538,    36,  3361, 11344,     4,   163,    25,     6,
        64226, 83619,    99,     5, 45212,     3,   869,     7,    23, 86632,
        43650,  6866,     2])]


In [74]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    texts, labels = [], []
    for label, txt in batch:
        texts.append(txt)
        labels.append(label) 
    
    labels = torch.tensor(labels, dtype=torch.uint8)    
    lengths = [len(txt) for txt in texts]
    lengths = torch.tensor(lengths, dtype=torch.short)

    return texts, labels, lengths

dataloader = DataLoader(train, batch_size=3, collate_fn=collate_fn, shuffle=True)
for idx, (texts, labels, lengths) in enumerate(dataloader):
    print(idx, labels, lengths)
    print(texts)
    break

0 tensor([4, 4, 4], dtype=torch.uint8) tensor([37, 24, 32], dtype=torch.int16)
[tensor([ 1088,     5,   417,   223,     3,  6404,   580,   417,  3604,   234,
           68,    17,  1001,  4037,    26,   318,    26,    22,  1088,    17,
           10,  4129,   727,    22,    11,    23,    70,  8186,     2,  1088,
           85,    68,    17,  1001, 45493,  1951,     2]), tensor([  189,   860,  7852,  2268,   242,  3397,    43,  6983,    19,   265,
           18,    34,   970,   862,     5,  3140, 11945,  1669,     5,    41,
          455,   286,   158,     2]), tensor([   99,  1151,  6438,  2225,  6231,    25,   601, 32515,  6002,     4,
         1435,     2,    14,    32,    15,    54,     3,  6438,  2778,    18,
          481,  1169,    39,  1130,  1577,  2216,    43,   278,  3000,     2,
            2,     2])]


In [107]:
# vocabul
train.vocab.itos[:10]
PAD_IDX = train.vocab.stoi['<pad>']
UNK_IDX = train.vocab.stoi['<unk>']
log.info(f'vocabulary size= {len(train.vocab):,}')
log.info(f'<unk>={UNK_IDX}')
log.info(f'<pad>={PAD_IDX}')

from collections import Counter
all_labels = Counter(label for label, txt in train.data)
print(all_labels)

INFO:root:vocabulary size= 95,812
INFO:root:<unk>=0
INFO:root:<pad>=1


Counter({3: 30000, 4: 30000, 2: 30000, 1: 30000})


In [79]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    texts, labels = [], []
    for label, txt in batch:
        texts.append(txt)
        labels.append(label)

    labels = torch.tensor(labels, dtype=torch.uint8)    
    lengths = [len(txt) for txt in texts]
    lengths = torch.tensor(lengths, dtype=torch.short)

    seqs = torch.full(size=(len(texts), lengths.max()),
                      fill_value=PAD_IDX, dtype=torch.long)
    for idx, txt in enumerate(texts):
        seqs[idx, :len(txt)] = txt
    
    return seqs, labels, lengths
dataloader = DataLoader(train, batch_size=3, collate_fn=collate_fn, shuffle=True)

for idx, (texts, labels, lengths) in enumerate(dataloader):
    print(idx, labels, lengths)
    print(texts)
    break

0 tensor([4, 1, 3], dtype=torch.uint8) tensor([28, 25, 67], dtype=torch.int16)
tensor([[ 5111,  1106,  6614,  3726,   333,     7, 16024,   964,    14,   405,
            51,    15,   405,    51,    16,  1556,  7291,    25,  1929,   383,
           105,  6011, 17553,   573,     4,   524,    85,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1],
        [  197,  3424,   161,   606,   122,   197,   564,    34,    38,  1227,
             5,  1169,     3,   290,   235,    64,   271,    35,     6,  1661,
           737,   122,    22,   859,     2,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1, 

In [133]:
from torch.utils.data import DataLoader

from dataclasses import dataclass


@dataclass
class Batch:
    texts: torch.Tensor
    labels: torch.Tensor
    lengths: torch.Tensor

    def __len__(self):
        return len(seqs) 

    def tok_count():
        return self.lengths.sum()

    def to(self, device):
        self.texts = self.texts.to(device)
        self.labels = self.labels.to(device)
        self.lengths = self.lengths.to(device)        
        return self

    def pin_memory(self):
        self.texts = self.texts.pin_memory()
        self.labels = self.labels.pin_memory()
        self.lengths = self.lengths.pin_memory()        
        return self

    @classmethod
    def collate_fn(cls, batch) -> 'Batch':
        texts, labels = [], []
        for label, txt in batch:
            texts.append(txt)
            labels.append(label)

        labels = torch.tensor(labels, dtype=torch.uint8)    
        lengths = [len(txt) for txt in texts]
        lengths = torch.tensor(lengths, dtype=torch.short)

        seqs = torch.full(size=(len(texts), lengths.max()),
                          fill_value=PAD_IDX, dtype=torch.long)
        for idx, txt in enumerate(texts):
            seqs[idx, :len(txt)] = txt
    
        return cls(texts=seqs, labels=labels, lengths=lengths)

dataloader = DataLoader(train, batch_size=3, collate_fn=Batch.collate_fn, shuffle=True, pin_memory=True)

for idx, batch in enumerate(dataloader):
    print(idx, batch.labels, batch.lengths)
    print(batch.texts)
    break

0 tensor([2, 2, 1], dtype=torch.uint8) tensor([31, 35, 45], dtype=torch.int16)
tensor([[   97,    17,    10,  1970,  2149,   910,  5452,  2492,   417,  1114,
             2,   145,    14,    90,   218,    15,    21, 38271,   417,  3176,
             4, 23193,     4,  6519,     2,     4,   442,   954,     2,   595,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1],
        [15393,  1155,    26,   503,   640,    26,    96,    13,  1532,    82,
         27233,   157,   176,  2900, 15393,    87,     6,  1120,    48,   314,
          3222,     2,     3,   250,  1550,    87,    82,   665,     4,    50,
          2993, 11653,    12,    18,     2,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1],
        [ 6791,  2264,    12,   147,     8,   486,   497,     3,    24,   169,
           486,    34,   291,   377,    11,    56,    35,    23,  5859,    77,
             4, 11977,  5766,  6791,     

In [134]:
BATCH_SIZE = 4
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=Batch.collate_fn, pin_memory=True)
valid_loader = DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False, collate_fn=Batch.collate_fn, pin_memory=True)

In [138]:
from tqdm import tqdm
steps = 100000

count = 0
epoch = 0
while count < steps:
    # https://github.com/tqdm/tqdm#documentation 
    with tqdm(train_loader, mininterval=0.2) as bar:
        for batch in bar:
            count += 1
            bar.set_postfix(dict(updates=count, epoch=epoch))
            if count % 100 == 0:  # testing 
                break
    log.info(f'Epoch {epoch} completed')
    epoch += 1
    if epoch > 2:
        break

  0%|          | 99/30000 [00:00<02:15, 220.22it/s, updates=100, epoch=0]
INFO:root:Epoch 0 completed
  0%|          | 99/30000 [00:00<02:01, 245.55it/s, updates=200, epoch=1]
INFO:root:Epoch 1 completed
  0%|          | 99/30000 [00:00<01:58, 253.27it/s, updates=300, epoch=2]
INFO:root:Epoch 2 completed


In [99]:
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim

from dataclasses import dataclass

@dataclass
class Trainer:
    model: nn.Module
    opt: optim.Optimizer
    #loss_func = nn.CrossEntropyLoss()      # object oriented API
    loss_func = F.cross_entropy             # functional API 

# I use object-oriented API for components with states
#    and functional API for stateless components

In [100]:
@dataclass
class Trainer:
    model: nn.Module
    opt: optim.Optimizer = None
    loss_func = F.cross_entropy

    def validate(self, validate):
        pass

    def checkpoint(self):
        pass

    def train(self, train_loader: DataLoader, valid_data: DataLoader, steps: int, checkpoint: int):
        pass


In [101]:
@dataclass
class Trainer:
    model: nn.Module
    opt: optim.Optimizer = None
    loss_func = F.cross_entropy
    device = device

    def __post_init__():
        self.model = self.model.to(self.device)
        
        
    def validate(self, validate) -> float:
        pass

    def checkpoint(self):
        pass

    def train(self, train_loader: DataLoader, valid_data: DataLoader, steps: int, checkpoint: int):        
        count = 0
        epoch = 0
        train_loss = 0
        while count < steps:
            with tqdm(train_loader) as databar:
                for batch in databar:
                    count += 1
                    bar.set_postfix(dict(updates=count))
                    
            log.info(f'Epoch {epoch} completed')
            epoch += 1

In [166]:
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

class TextClassifier(nn.Module):  # all modules should subclass nn.Module
    
    def __init__(self, vocab_size: int, n_classes: int, model_dim=256, n_heads=4, n_layers=4,
                 ff_dim=1024, dropout=0.1, activation='relu'):
        super().__init__() # remember to call super

        self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                       embedding_dim=model_dim, padding_idx=PAD_IDX)
        # TODO: positional encoding

        enc_layer = TransformerEncoderLayer(d_model=model_dim, nhead=n_heads, dim_feedforward=ff_dim,
                                                   dropout=dropout, activation=activation)
        self.encoder = TransformerEncoder(enc_layer, num_layers=n_layers)
        
        self.cls_proj = nn.Linear(model_dim, n_classes)
        
    def forward(self, texts, lengths, out='probs'):
        # [Batch x Length] --> [Batch x Length x HidDim]
        embs = self.embeddings(texts)
        # TODO: mask out the padding idxs

        # some modules accept batch as second dimension
        embs = embs.transpose(0, 1)            #[Length x Batch x HidDim]
        feats = self.encoder(embs)
        feats = feats.transpose(0, 1)          #[Batch x Length x HidDim]

        #TODO: sentence representation
        max_feats, max_indices = feats.max(dim=1, keepdim=False)  #[Batch x HidDim]
        cls_logits = self.cls_proj(max_feats)       #[Batch x Classes]

        return {'probs': F.softmax,
         'log_probs': F.log_softmax, 
         'raw': lambda x: x
        }[out](cls_logits)

vocab = train.vocab
n_classes = max(all_labels) + 1
model_args = dict(vocab_size=len(vocab), n_classes=n_classes,
                  model_dim=128, n_heads=2, n_layers=2, ff_dim=256)
# Note: save model_args somewhere


model = TextClassifier(**model_args)

In [181]:
class Trainer:

    def __init__(self, model, lr=5e-4, device=device,
                 opt=None, loss_func=None, lr_scheduler=None):
        self.device = device        
        self.model = model.to(self.device)
        self.opt = opt or optim.Adam(self.model.parameters(), lr=lr)
        self.loss_func = loss_func or F.cross_entropy
        self.lr_scheduler = lr_scheduler
        
    def validate(self, valid_loader) -> float:
        total = 0.
        count = 0
        for batch in valid_loader:
            scores = self.model(texts=batch.texts, lengths=batch.lengths, out='raw')
            loss = self.loss_func(input=scores, target=batch.labels.long(), reduction='mean')
            total += loss.item()
            count += 1
        return total / count

    def checkpoint(self):
        pass

    def train(self, train_loader: DataLoader, valid_loader: DataLoader, steps: int, checkpoint: int):        
        count = 0
        epoch = 0
        train_loss = 0.
        self.model.train(True) #Training mode
        while count < steps:
            with tqdm(train_loader) as databar:
                for idx, batch in enumerate(databar):

                    scores = self.model(texts=batch.texts, lengths=batch.lengths, out='raw')
                    # NOTE: loss_func accepts long values for target 
                    loss = self.loss_func(input=scores, target=batch.labels.long(), reduction='mean')

                    loss.backward()
                    self.opt.step()
                    self.opt.zero_grad()

                    loss_val = loss.detach().item()
                    train_loss += loss_val 

                    count += 1
                    databar.set_postfix(dict(updates=count, loss=loss_val), refresh=False)
                    if count % checkpoint == 0:
                        with torch.no_grad():
                            self.model.train(False)
                            val_loss = self.validate(valid_loader)
                            train_loss /= checkpoint
                            log.info(f'\nCheckpoint at {count}; train_loss={train_loss:.4f} valid_loss={val_loss:.4f}')
                            # TODO: checkpoint
                            self.model.train(True)
                            train_loss = 0
            log.info(f'Epoch {epoch} completed')
            epoch += 1


In [182]:

trainer = Trainer(model=model)
trainer.train(train_loader=train_loader, valid_loader=valid_loader, steps=10000, checkpoint=100)

  0%|          | 99/30000 [00:12<1:02:19,  8.00it/s, updates=99, loss=0.146] INFO:root:
Checkpoint at 100; train_loss=0.6577 valid_loss=0.5691
  1%|          | 199/30000 [00:31<1:02:02,  8.00it/s, updates=199, loss=1.39]  INFO:root:
Checkpoint at 200; train_loss=0.5626 valid_loss=0.6243
  1%|          | 299/30000 [00:49<59:53,  8.26it/s, updates=299, loss=1.62]    INFO:root:
Checkpoint at 300; train_loss=0.6000 valid_loss=0.5624
  1%|▏         | 399/30000 [01:07<1:00:30,  8.15it/s, updates=399, loss=1.45]  INFO:root:
Checkpoint at 400; train_loss=0.6077 valid_loss=0.5467
  2%|▏         | 499/30000 [01:26<1:00:54,  8.07it/s, updates=499, loss=1.23]  INFO:root:
Checkpoint at 500; train_loss=0.5713 valid_loss=0.5664
  2%|▏         | 599/30000 [01:45<1:02:35,  7.83it/s, updates=599, loss=0.152] INFO:root:
Checkpoint at 600; train_loss=0.5854 valid_loss=0.5339
  2%|▏         | 699/30000 [02:05<57:47,  8.45it/s, updates=699, loss=0.739]   INFO:root:
Checkpoint at 700; train_loss=0.5891 valid

KeyboardInterrupt: 