# Good/Bad vocab learning task

In [144]:
import csv
from pathlib import Path
import logging as log
import itertools
import torch

log.basicConfig(level=log.INFO)


In [145]:
all_data_path = Path('good-bad-vocab-learning-task.tsv')

In [146]:
!head -12 {all_data_path}

Experiment	n:	src_toks	src_mean_len	src_EMD	src_effective_n	src_f90p	src_f95p	src_f98p	src_f99p	src_f100p	tgt_toks	tgt_mean_len	tgt_EMD	tgt_effective_n	tgt_f90p	tgt_f95p	tgt_f98p	tgt_f99p	tgt_f100p	BLEU_dev
runs-deen-030k/000-deen-chars-r1	Chars	4,882,920	162.76	0.743	129	22	16	15	15	13	4,369,261	145.64	0.773	128	32	23	19	19	19	14.4
runs-deen-030k/011-deen-.5k.5k-r1	500	2,119,390	70.65	0.420	493	358	39	17	16	13	1,880,587	62.69	0.427	492	122	50	25	20	9	13.3
runs-deen-030k/022-deen-01k01k-r1	1,000	1,749,715	58.32	0.444	982	261	70	22	16	2	1,557,676	51.92	0.463	984	151	59	24	13	1	17.1
runs-deen-030k/033-deen-02k02k-r1	2,000	1,472,822	49.09	0.476	1,969	112	47	17	10	1	1,305,579	43.52	0.497	1,973	89	43	18	6	1	16.5
runs-deen-030k/044-deen-04k04k-r1	4,000	1,255,305	41.84	0.513	3,930	45	21	8	4	1	1,116,064	37.20	0.542	3,932	35	18	6	3	1	16.4
runs-deen-030k/055-deen-08k08k-r1	8,000	1,088,556	36.29	0.557	7,833	18	9	3	2	1	986,651	32.89	0.603	7,804	12	6	2	1	1	14.9
runs-deen-030k/066-deen-16k16k-r1	16,

In [150]:
def read_all(path):
    data = [[]]
    head = None
    with path.open() as p:
        for i, line in enumerate(p):
            line = line.strip()
            if i == 0:
                head = line.split('\t')
                continue
            if not line:
                # empty lines separate suit
                data.append([])
                continue
            rec = line.split('\t')
            if len(rec) != len(head):
                log.warning(f'skip: {line}')
                continue
            assert len(rec) == len(head)
            for i in range(1, len(rec)):
                if rec[i] == 'Chars':
                    rec[i] = "100"  # some placeholder
                rec[i] = float(rec[i].replace(',', ''))
            data[-1].append(rec)
    return head, data
header, all_data = read_all(all_data_path)
print(len(all_data), sum(1 for grp in all_data for ex in grp))

12 113


In [140]:
class MyDataset(torch.utils.data.Dataset):
    
    NEG_CLS = 0
    POS_CLS = 1
    SAME_CLS = 2
    
    def __init__(self, header, data, has_name=True, has_score=True, significance=0.1):
        super().__init__()
        # data comes in groups, here we flatten it
        # first column may have example name, last column may have BLEU score
        self.feat_names = header
        self.names = None
        assert significance >= 0.
        self.significance = significance # change in score
        if has_name:
            self.feat_names = self.feat_names[1:]
            self.names = [ex[0] for grp in data for ex in grp]
            data = [[ex[1:] for ex in grp] for grp in data]
        self.scores = None
        if has_score:
            self.feat_names = self.feat_names[:-1]            
            self.scores = [ex[-1] for grp in data for ex in grp]
            data = [[ex[:-1] for ex in grp] for grp in data]

        n = sum(1 for grp in data for ex in grp)

        self.data = torch.zeros(n, len(self.feat_names), dtype=torch.float)
        self.groups = []
        idx = 0
        for group in data:
            self.groups.append((idx, idx + len(group)))
            for ex in group:
                self.data[idx] = torch.tensor(ex, dtype=torch.float)
                idx += 1
        self.n_feats = len(self.feat_names) * 2 # concatenation
        self.n_class = 3
        self.pairs = self.make_pairs()

    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        lhs, rhs = self.names[pair[0]], self.names[pair[1]]
        # lets use lhs or x1,y1 as anchor
        #      rhs or x2,y2 as comparison  
        # if y2 << y1 then its Negative, i.e damaging  => Class 0
        # if y2 >> y1 then its Positive, i.e improvement => class 1
        # if y2 ~ y1 within significance, then same  => Class 2

        x1, x2 = self.data[pair[0]], self.data[pair[1]]
        x = torch.cat((x1, x2), dim=0)
        if self.scores:
            y1, y2 = self.scores[pair[0]], self.scores[pair[1]]
            if abs(y2 - y1) <= self.significance:
                y = self.SAME_CLS
            elif y2 < y1:
                y = self.NEG_CLS
            else:
                y = self.POS_CLS
            return x, y
        return x

    def _prepare(self, data):

        # total
        n = sum(1 for grp in data for ex in grp)
        data_set = torch.zeros(n, len(header), dtype=torch.float)
        groups = []
        # make pairs within groups
        idx = 0
        for group in data:
            groups.append((idx, idx + len(group)))
            for ex in group:
                names.append(ex[0])
                data_set[idx] = torch.tensor(ex[1:] , dtype=torch.float)
                idx += 1
        return data_set, groups, names
    
    def make_pairs(self):
        pairs = []
        for start, end in self.groups:
            for lhs in range(start, end):
                for rhs in range(start, end):
                    pairs.append([lhs, rhs])
        return torch.tensor(pairs, dtype=torch.int)

train_ds =  MyDataset(head, all_data[2:], has_name=True, has_score=True)
valid_ds =  MyDataset(head, all_data[0:1], has_name=True, has_score=True)
test_ds =  MyDataset(head, all_data[1:2], has_name=True, has_score=True)
log.info(f'Train: {len(train_ds)} Validn: {len(valid_ds)} Test: {len(test_ds)} examples')

INFO:root:Train: 713 Validn: 64 Test: 100 examples


In [141]:
params = {'batch_size': 24,
          'shuffle': True,
          'num_workers': 1}

train_gen = torch.utils.data.DataLoader(train_ds, **params)
valid_gen = torch.utils.data.DataLoader(test_ds, **params)

In [142]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm
import numpy as np 

class Classifier(nn.Module):
    
    def __init__(self, n_feats, n_class, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(n_feats),
            nn.Linear(n_feats, n_feats//2),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(n_feats//2, n_feats),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(n_feats, n_class)
        )
    
    def forward(self, batch_x, get='logprobs'):
        scores = self.net(batch_x)
        if get == 'logit':
            return scores
        elif get == 'prob':
            return F.softmax(scores, dim=-1)
        elif get == 'logprob':
            return F.log_softmax(scores, dim=-1)
        else:
            raise Exection(f'Unkown: {get}; Known: logit, prob, logprob')

class Trainer:
    
    def __init__(self, model, optim=None, loss_func=None):
        self.model = model.train()
        self.optim = optim or torch.optim.Adam(model.parameters())
        self.loss_func = loss_func or nn.NLLLoss(reduction='none')
        
    def validate(self, data):
        ep_losses = []
        truth = []
        preds = []
        for xs, ys in data:
            logprobs = self.model(xs, get='logprob')
            truth.extend(ys.tolist())
            preds.extend(logprobs.argmax(dim=-1).tolist())
            loss = self.loss_func(logprobs, ys).sum(dim=-1).mean(dim=0)
            ep_losses.append(loss.item())

        evalr = MultiClassEvaluator(truth, preds)
        print(np.round(evalr.f1s, 2))
        return sum(ep_losses) / len(ep_losses)

            
    def train(self, train_data, val_data, epochs, patience=20):
        
        with tqdm(range(epochs), total=epochs, unit='epoch') as ep_bar:
            val_losses = []
            n_steps = 0
            for ep in ep_bar:
                tr_losses = []
                for xs, ys in train_data:
                    self.optim.zero_grad()
                    logprobs = self.model(xs, get='logprob')
                    loss = self.loss_func(logprobs, ys).sum(dim=-1).mean(dim=0)
                    loss.backward()
                    tr_losses.append(loss.item())
                    self.optim.step()
                    n_steps += 1

                train_loss = sum(tr_losses) / len(tr_losses)
                with torch.no_grad():
                    model.eval()
                    val_loss = self.validate(val_data)
                    val_losses.append(val_loss)
                    model.train()
                ep_bar.set_postfix(train_loss=train_loss, val_loss=val_loss, n_updates=n_steps, refresh=False)

                if len(val_losses) > patience + 1:
                    if val_losses[-patience -1] <= min(val_losses[-patience:]):
                        log.info(f"Early stop at epoch {ep+1}, at step = {n_steps}")
                        break


class MultiClassEvaluator():

    def __init__(self, truth, preds):       
        assert len(truth) == len(preds)
        self.truth = truth
        self.preds = preds
        self.classes = list(set(truth) | set(preds))
        self.n_classes = 1 + max(self.classes)
        self.mat = self.confusion_matrix()
        self.precs, self.recs, self.f1s = [np.zeros(self.n_classes, dtype=np.float)
                                        for _ in range(3)]
        for c in range(self.n_classes):
            tot_tru = self.mat[c, :].sum()
            tot_pred = self.mat[:, c].sum()
            assert self.mat[c, c] <= tot_tru
            assert self.mat[c, c] <= tot_pred
            self.precs[c] =  self.mat[c, c] / tot_pred if tot_pred > 0 else 1
            self.recs[c] =  self.mat[c, c] / tot_tru if tot_tru > 0 else 1
            if self.recs[c] + self.precs[c] > 0:
                self.f1s[c] =  2 * self.precs[c] * self.recs[c] / (self.recs[c] + self.precs[c])


    def confusion_matrix(self):
        mat = np.zeros((self.n_classes, self.n_classes), dtype=np.int)
        for tr, pr in zip(self.truth, self.preds):
            mat[tr][pr] += 1
        return mat
    

In [143]:
model = Classifier(train_ds.n_feats, train_ds.n_class)

trainer = Trainer(model)
trainer.train(train_gen, valid_gen, epochs=20000, patience=40)

HBox(children=(FloatProgress(value=0.0, max=20000.0), HTML(value='')))

[0.61 0.   0.  ]
[0.51 0.38 0.  ]
[0.41 0.42 0.  ]
[0.41 0.4  0.  ]
[0.38 0.41 0.  ]
[0.42 0.43 0.  ]
[0.42 0.43 0.  ]
[0.41 0.42 0.  ]
[0.41 0.38 0.  ]
[0.41 0.4  0.  ]
[0.41 0.4  0.  ]
[0.42 0.41 0.  ]
[0.41 0.38 0.  ]
[0.41 0.4  0.  ]
[0.41 0.4  0.  ]
[0.4  0.36 0.  ]
[0.4  0.36 0.  ]
[0.41 0.41 0.8 ]
[0.41 0.38 0.44]
[0.38 0.41 0.44]
[0.42 0.41 0.71]
[0.42 0.41 0.71]
[0.42 0.41 0.74]
[0.41 0.42 0.74]
[0.42 0.42 0.67]
[0.42 0.42 0.69]
[0.42 0.42 0.67]
[0.41 0.42 0.65]
[0.43 0.42 0.67]
[0.42 0.42 0.71]
[0.42 0.42 0.71]
[0.41 0.42 0.71]
[0.42 0.42 0.71]
[0.41 0.41 0.69]
[0.42 0.42 0.69]
[0.42 0.41 0.65]
[0.42 0.42 0.69]
[0.42 0.42 0.69]
[0.42 0.41 0.71]
[0.42 0.42 0.69]
[0.42 0.42 0.71]
[0.42 0.42 0.71]
[0.42 0.42 0.69]
[0.42 0.42 0.67]
[0.42 0.42 0.69]
[0.42 0.42 0.67]
[0.45 0.41 0.67]
[0.42 0.42 0.69]
[0.42 0.42 0.67]
[0.41 0.42 0.62]
[0.42 0.42 0.67]
[0.46 0.42 0.62]
[0.41 0.42 0.61]
[0.43 0.42 0.65]
[0.42 0.42 0.69]
[0.42 0.42 0.67]
[0.48 0.41 0.67]
[0.49 0.41 0.62]
[0.42 0.42 0.6

INFO:root:Early stop at epoch 368, at step = 11040


[0.55 0.55 0.8 ]

