In [1]:
# # https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/BertTagger.ipynb
# !pip install -qqq torchtext wandb pytorch-transformers
# !pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct

In [2]:
import torchtext
import torch
import torch.nn as nn
from torch_struct import LinearChainCRF
import torch_struct.data
import torchtext.data as data
from pytorch_transformers import *
config = {"bert": "bert-base-cased", "H" : 768, "dropout": 0.2}

# Comment or add your wandb
#import wandb
#wandb.init(project="pytorch-struct-tagging", config=config)

In [3]:
class ConllXDataset(data.Dataset):
    def __init__(self, path, fields, encoding='utf-8', separator='\t', **kwargs):
        examples = []
        columns = [[], []]
        column_map = {1: 0, 3: 1}
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line == '':
                    examples.append(data.Example.fromlist(columns, fields))
                    columns = [[], []]
                else:
                    for i, column in enumerate(line.split(separator)):
                        if i in column_map:
                            columns[column_map[i]].append(column)
            examples.append(data.Example.fromlist(columns, fields))
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)

In [4]:
model_class, tokenizer_class, pretrained_weights = BertModel, BertTokenizer, config["bert"]
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)    
WORD = torch_struct.data.SubTokenizedField(tokenizer)
UD_TAG = torchtext.data.Field(init_token="<bos>", eos_token="<eos>", include_lengths=True)

# train, val, test = torchtext.datasets.UDPOS.splits(
#     fields=(('word', WORD), ('udtag', UD_TAG), (None, None)), 
#     filter_pred=lambda ex: len(ex.word[0]) < 200
# )
fields=(('word', WORD), ('udtag', UD_TAG), (None, None))
train = ConllXDataset('wsj.train0.conllx', fields)
val = ConllXDataset('wsj.train0.conllx', fields)

#WORD.build_vocab(train.word, min_freq=3)
UD_TAG.build_vocab(train.udtag)

train_iter = torch_struct.data.TokenBucket(train, 10, device="cpu")
val_iter = torchtext.data.BucketIterator(val, 
    batch_size=10,
    device="cpu")

In [5]:
vars(train_iter)

{'batch_size': 10,
 'train': True,
 'dataset': <__main__.ConllXDataset at 0x10f162ed0>,
 'batch_size_fn': <function torch_struct.data.data.TokenBucket.<locals>.batch_size_fn(x, _, size)>,
 'iterations': 0,
 'repeat': True,
 'shuffle': True,
 'sort': False,
 'sort_within_batch': True,
 'sort_key': <function torch_struct.data.data.TokenBucket.<locals>.<lambda>(x)>,
 'device': 'cpu',
 'random_shuffler': <torchtext.data.utils.RandomShuffler at 0x10f000bd0>,
 '_iterations_this_epoch': 0,
 '_random_state_this_epoch': None,
 '_restored_from_state': False}

In [6]:
C = len(UD_TAG.vocab)

class Model(nn.Module):
    def __init__(self, hidden, classes):
        super().__init__()
        self.base_model = model_class.from_pretrained(pretrained_weights)
        self.linear = nn.Linear(hidden, C)
        self.transition = nn.Linear(C, C)
        self.dropout = nn.Dropout(config["dropout"])
        
    def forward(self, words, mapper):
        out = self.dropout(self.base_model(words)[0]) # N x H
        out = torch.einsum("bca,bch->bah", mapper.float(), out) #.cuda() # (N x N) (N x H) -> N x H
        final = torch.einsum("bnh,ch->bnc", out, self.linear.weight) # (N x H) (H x C) -> N x C
        batch, N, C = final.shape
        #print(final.view(batch, N, C, 1).shape)
        #print(final.view(batch, N, C, 1)[:, 1:N].shape)
        vals = final.view(batch, N, C, 1)[:, 1:N] + self.transition.weight.view(1, 1, C, C)
        #print(vals.shape)
        vals[:, 0, :, :] += final.view(batch, N, 1, C)[:, 0] 
        return vals

In [7]:
model = Model(config["H"], C)

x = next(iter(train_iter))
words, mapper, _ = x.word
label, lengths = x.udtag
log_potentials = model(words, mapper)

In [8]:
mapper.shape

torch.Size([1, 30, 22])

In [9]:
words.shape

torch.Size([1, 30])

In [10]:
label.shape

torch.Size([22, 1])

In [11]:
log_potentials.shape

torch.Size([1, 21, 32, 32])

In [12]:
log_potentials[:, 0, :, :].shape

torch.Size([1, 32, 32])

In [13]:
log_potentials[:, 1, :, :].shape

torch.Size([1, 32, 32])

In [14]:
model = Model(config["H"], C)
#wandb.watch(model)
#model.cuda()

def validate(itera):
    incorrect_edges = 0
    total = 0 
    model.eval()
    for i, ex in enumerate(itera):
        words, mapper, _ = ex.word
        label, lengths = ex.udtag
        dist = LinearChainCRF(model(words, mapper), #.cuda()
                              lengths=lengths)        
        argmax = dist.argmax
        gold = LinearChainCRF.struct.to_parts(label.transpose(0, 1), C,
                                              lengths=lengths).type_as(argmax)
        incorrect_edges += (argmax.sum(-1) - gold.sum(-1)).abs().sum() / 2.0
        total += argmax.sum()            
        
    model.train()    
    return incorrect_edges / total   
    
def train(train_iter, val_iter, model):
    opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)

    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.udtag
        N_1, batch = label.shape

        # Model
        log_potentials = model(words, mapper) #.cuda()
        if not lengths.max() <= log_potentials.shape[1] + 1:
            print("fail")
            continue

        dist = LinearChainCRF(log_potentials,
                              lengths=lengths) #lengths.cuda()   

        
        labels = LinearChainCRF.struct.to_parts(label.transpose(0, 1), C, lengths=lengths) \
                            .type_as(dist.log_potentials)
        loss = dist.log_prob(labels).sum()
        (-loss).backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        scheduler.step()

        losses.append(loss.detach())
        
        
        if i % 100 == 10:            
            print(-torch.tensor(losses).mean(), words.shape)
            val_loss = validate(val_iter)
            #wandb.log({"train_loss":-torch.tensor(losses).mean(), 
            #           "val_loss" : val_loss})
            losses = []

In [None]:
train(train_iter, val_iter, model) #.cuda()

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha)


tensor(68.1869) torch.Size([1, 25])
tensor(22.1560) torch.Size([1, 74])
tensor(1.0852) torch.Size([1, 13])
tensor(1.2359) torch.Size([1, 13])
tensor(0.4764) torch.Size([1, 20])
tensor(0.1712) torch.Size([1, 41])
tensor(0.0022) torch.Size([1, 35])
tensor(0.0015) torch.Size([1, 13])
tensor(0.0611) torch.Size([1, 74])
tensor(0.0385) torch.Size([1, 41])
tensor(0.0009) torch.Size([1, 25])
tensor(0.0008) torch.Size([1, 20])
