In [None]:
# !pip install -qqq torchtext wandb pytorch-transformers
# !pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct

In [2]:
import torchtext
import torch
import torch.nn as nn
from torch_struct import LinearChainCRF
import torch_struct.data
import torchtext.data as data
from pytorch_transformers import *
config = {"bert": "bert-base-cased", "H" : 768, "dropout": 0.2}

# Comment or add your wandb
#import wandb
#wandb.init(project="pytorch-struct-tagging", config=config)

In [3]:
class ConllXDataset(data.Dataset):
    def __init__(self, path, fields, encoding='utf-8', separator='\t', **kwargs):
        examples = []
        columns = [[], []]
        column_map = {1: 0, 3: 1}
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line == '':
                    examples.append(data.Example.fromlist(columns, fields))
                    columns = [[], []]
                else:
                    for i, column in enumerate(line.split(separator)):
                        if i in column_map:
                            columns[column_map[i]].append(column)
            examples.append(data.Example.fromlist(columns, fields))
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)

In [5]:
model_class, tokenizer_class, pretrained_weights = BertModel, BertTokenizer, config["bert"]
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)    
WORD = torch_struct.data.SubTokenizedField(tokenizer)
UD_TAG = torchtext.data.Field(init_token="<bos>", eos_token="<eos>", include_lengths=True)

# train, val, test = torchtext.datasets.UDPOS.splits(
#     fields=(('word', WORD), ('udtag', UD_TAG), (None, None)), 
#     filter_pred=lambda ex: len(ex.word[0]) < 200
# )
fields=(('word', WORD), ('udtag', UD_TAG), (None, None))
train = ConllXDataset('wsj.train0.conllx', fields)
val = ConllXDataset('wsj.train0.conllx', fields)

#WORD.build_vocab(train.word, min_freq=3)
UD_TAG.build_vocab(train.udtag)
train_iter = torch_struct.data.TokenBucket(train, 3, device="cpu")
val_iter = torchtext.data.BucketIterator(val, 
    batch_size=10,
    device="cpu")

In [6]:
vars(train_iter)

{'batch_size': 3,
 'train': True,
 'dataset': <__main__.ConllXDataset at 0x105c02550>,
 'batch_size_fn': <function torch_struct.data.data.TokenBucket.<locals>.batch_size_fn(x, _, size)>,
 'iterations': 0,
 'repeat': True,
 'shuffle': True,
 'sort': False,
 'sort_within_batch': True,
 'sort_key': <function torch_struct.data.data.TokenBucket.<locals>.<lambda>(x)>,
 'device': 'cpu',
 'random_shuffler': <torchtext.data.utils.RandomShuffler at 0x1202923d0>,
 '_iterations_this_epoch': 0,
 '_random_state_this_epoch': None,
 '_restored_from_state': False}

In [7]:
C = len(UD_TAG.vocab)

class Model(nn.Module):
    def __init__(self, hidden, classes):
        super().__init__()
        self.base_model = model_class.from_pretrained(pretrained_weights)
        self.linear = nn.Linear(hidden, C)
        self.transition = nn.Linear(C, C)
        self.dropout = nn.Dropout(config["dropout"])
        
    def forward(self, words, mapper):
        out = self.dropout(self.base_model(words)[0])
        out = torch.einsum("bca,bch->bah", mapper.float(), out) #.cuda()
        final = torch.einsum("bnh,ch->bnc", out, self.linear.weight)
        batch, N, C = final.shape
        vals = final.view(batch, N, C, 1)[:, 1:N] + self.transition.weight.view(1, 1, C, C)
        vals[:, 0, :, :] += final.view(batch, N, 1, C)[:, 0] 
        return vals

In [10]:
model = Model(config["H"], C)

x = next(iter(train_iter))
words, mapper, _ = x.word
label, lengths = x.udtag
log_potentials = model(words, mapper)

tensor([[[[ 5.2993e-01, -1.3435e-04,  9.2094e-01,  ...,  3.2320e-01,
           -8.2655e-03, -1.6876e-01],
          [-8.4651e-01, -1.1285e+00, -2.1026e-01,  ..., -8.3672e-01,
           -8.6365e-01, -1.4539e+00],
          [-6.5025e-02, -3.0551e-01,  3.9340e-01,  ...,  7.1317e-02,
           -2.2542e-01, -4.1513e-01],
          ...,
          [ 2.7483e-02, -2.8510e-01,  4.1860e-01,  ...,  4.2548e-02,
           -2.1091e-01, -3.6145e-01],
          [ 4.5444e-01, -7.1339e-03,  5.7791e-01,  ...,  3.9465e-01,
           -2.3393e-02, -3.1620e-01],
          [-8.5504e-01, -9.9617e-01, -3.2160e-01,  ..., -9.7899e-01,
           -1.2142e+00, -1.5192e+00]],

         [[ 6.1942e-01,  3.9762e-01,  6.6567e-01,  ...,  5.2211e-01,
            3.4508e-01,  5.3072e-01],
          [-5.9999e-01, -5.7371e-01, -3.0849e-01,  ..., -4.8078e-01,
           -3.5328e-01, -5.9742e-01],
          [-3.0793e-01, -2.4016e-01, -1.9426e-01,  ..., -6.2171e-02,
           -2.0448e-01, -4.8047e-02],
          ...,
     

tensor([[[[ 5.2993e-01, -1.3435e-04,  9.2094e-01,  ...,  3.2320e-01,
           -8.2655e-03, -1.6876e-01],
          [-8.4651e-01, -1.1285e+00, -2.1026e-01,  ..., -8.3672e-01,
           -8.6365e-01, -1.4539e+00],
          [-6.5025e-02, -3.0551e-01,  3.9340e-01,  ...,  7.1317e-02,
           -2.2542e-01, -4.1513e-01],
          ...,
          [ 2.7483e-02, -2.8510e-01,  4.1860e-01,  ...,  4.2548e-02,
           -2.1091e-01, -3.6145e-01],
          [ 4.5444e-01, -7.1339e-03,  5.7791e-01,  ...,  3.9465e-01,
           -2.3393e-02, -3.1620e-01],
          [-8.5504e-01, -9.9617e-01, -3.2160e-01,  ..., -9.7899e-01,
           -1.2142e+00, -1.5192e+00]],

         [[ 6.1942e-01,  3.9762e-01,  6.6567e-01,  ...,  5.2211e-01,
            3.4508e-01,  5.3072e-01],
          [-5.9999e-01, -5.7371e-01, -3.0849e-01,  ..., -4.8078e-01,
           -3.5328e-01, -5.9742e-01],
          [-3.0793e-01, -2.4016e-01, -1.9426e-01,  ..., -6.2171e-02,
           -2.0448e-01, -4.8047e-02],
          ...,
     

In [24]:
mapper.shape

torch.Size([1, 10, 7])

In [26]:
words.shape

torch.Size([1, 10])

In [25]:
label.shape

torch.Size([7, 1])

In [29]:
log_potentials.shape

torch.Size([1, 6, 32, 32])

In [13]:
log_potentials[:, 1, :, :]

tensor([[[ 0.6194,  0.3976,  0.6657,  ...,  0.5221,  0.3451,  0.5307],
         [-0.6000, -0.5737, -0.3085,  ..., -0.4808, -0.3533, -0.5974],
         [-0.3079, -0.2402, -0.1943,  ..., -0.0622, -0.2045, -0.0480],
         ...,
         [-0.0874, -0.0918, -0.0411,  ...,  0.0370, -0.0620,  0.1336],
         [ 0.4479,  0.2946,  0.2266,  ...,  0.4975,  0.2339,  0.2872],
         [-0.7234, -0.5562, -0.5347,  ..., -0.7379, -0.8187, -0.7776]]],
       grad_fn=<SliceBackward>)

In [15]:
model = Model(config["H"], C)
#wandb.watch(model)
#model.cuda()

def validate(itera):
    incorrect_edges = 0
    total = 0 
    model.eval()
    for i, ex in enumerate(itera):
        words, mapper, _ = ex.word
        label, lengths = ex.udtag
        dist = LinearChainCRF(model(words, mapper), #.cuda()
                              lengths=lengths)        
        argmax = dist.argmax
        gold = LinearChainCRF.struct.to_parts(label.transpose(0, 1), C,
                                              lengths=lengths).type_as(argmax)
        incorrect_edges += (argmax.sum(-1) - gold.sum(-1)).abs().sum() / 2.0
        total += argmax.sum()            
        
    model.train()    
    return incorrect_edges / total   
    
def train(train_iter, val_iter, model):
    opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)

    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.udtag
        N_1, batch = label.shape

        # Model
        log_potentials = model(words, mapper) #.cuda()
        if not lengths.max() <= log_potentials.shape[1] + 1:
            print("fail")
            continue

        dist = LinearChainCRF(log_potentials,
                              lengths=lengths) #lengths.cuda()   

        
        labels = LinearChainCRF.struct.to_parts(label.transpose(0, 1), C, lengths=lengths) \
                            .type_as(dist.log_potentials)
        loss = dist.log_prob(labels).sum()
        (-loss).backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        scheduler.step()

        losses.append(loss.detach())
        
        
        if i % 100 == 10:            
            print(-torch.tensor(losses).mean(), words.shape)
            val_loss = validate(val_iter)
            #wandb.log({"train_loss":-torch.tensor(losses).mean(), 
            #           "val_loss" : val_loss})
            losses = []

In [None]:
train(train_iter, val_iter, model) #.cuda()

tensor(1482.2295) torch.Size([21, 35])
tensor(527.6224) torch.Size([39, 19])
tensor(314.7615) torch.Size([19, 38])
tensor(222.4754) torch.Size([17, 44])
tensor(176.3658) torch.Size([48, 16])
fail
tensor(135.1218) torch.Size([46, 16])
tensor(140.6739) torch.Size([10, 74])
tensor(121.7297) torch.Size([4, 111])
tensor(106.9672) torch.Size([17, 44])
tensor(88.6975) torch.Size([18, 40])
tensor(73.7374) torch.Size([18, 41])
