In [1]:
import os
import random
from argparse import Namespace
import copy

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Sampler, DataLoader
from tqdm import tqdm

from fairseq.modules import TransformerEncoderLayer

In [2]:
# Languages used for the classification task; MyDataset labels each sample with
# the position of its language in this list.
langs = [
    "fr", "de", "es",
    "it", "ru", "pt",
    "nl", "sv-SE", "sl",
]
# Further candidates noted alongside the list but not enabled:
#   "zh-CN", "pt", "fa", "et", "mn", "nl",
#   "tr", "ar", "sv-SE", "lv", "sl", "ta", "ja", "id"

# Device that MyDataset.collate moves every batch to.
device = 'cuda:2'

In [3]:
class MyDataset(Dataset):
    """Per-utterance encoder states paired with a language-id label.

    Every file under ``<root>/<lang>/`` is loaded with ``th.load``. The first
    element of the loaded object must expose ``encoder_out`` (indexed here as
    time x batch x feature) and ``encoder_padding_mask`` (batch x time, True at
    padded steps). Each utterance is stored with its padded steps removed, and
    its label is the position of its language in ``langs``.
    """

    DEFAULT_ROOT = '/mnt/raid0/siqi/analysis/resources'

    def __init__(self, langs, root=DEFAULT_ROOT):
        """Load all cached encoder outputs into memory.

        Args:
            langs: language codes; each one names a sub-directory of ``root``.
            root: directory holding one sub-directory per language.

        Raises:
            FileNotFoundError: if a language sub-directory does not exist.
        """
        self.langs = langs
        self.data = []
        self.labels = []
        for lang_id, lang in enumerate(langs):
            lang_dir = os.path.join(root, lang)
            # os.listdir returns entries in arbitrary order; sorting makes the
            # sample order reproducible across runs and machines.
            for fname in tqdm(sorted(os.listdir(lang_dir)), desc='Lang {}'.format(lang)):
                # NOTE(security): th.load unpickles the file, so `root` must only
                # point at trusted dumps.
                encoder_out = th.load(os.path.join(lang_dir, fname), map_location='cpu')[0]
                x = encoder_out.encoder_out
                padding_mask = encoder_out.encoder_padding_mask
                for i in range(padding_mask.size(0)):
                    # Column i of the batch, keeping only the non-padded steps.
                    self.data.append(x[:, i, :][~padding_mask[i]])
                    self.labels.append(lang_id)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

    def __len__(self):
        """Number of stored utterances."""
        return len(self.data)

    def collate(self, indices, out_device=None):
        """Pad the selected samples into one time-major batch.

        Args:
            indices: non-empty sequence of sample indices.
            out_device: device for the returned tensors. When omitted, the
                module-level ``device`` is used (the notebook's original behavior).

        Returns:
            Tuple ``(inputs, padding_mask, labels)``: ``inputs`` is
            ``(max_len, batch, feat)`` zero-padded in the dtype of the stored
            features, ``padding_mask`` is a bool ``(batch, max_len)`` tensor that
            is True at padded steps, and ``labels`` is a LongTensor of size
            ``batch``.

        Raises:
            ValueError: if ``indices`` is empty.
        """
        if len(indices) == 0:
            raise ValueError('collate() needs at least one sample index')
        if out_device is None:
            out_device = device

        samples = [self[idx] for idx in indices]
        seqs = [data for data, _ in samples]
        lengths = th.tensor([seq.size(0) for seq in seqs])

        # pad_sequence stacks along a new dim 1 and pads with zeros created from
        # the sequences themselves, so the batch dtype no longer depends on
        # whether any padding happened to be needed.
        inputs = nn.utils.rnn.pad_sequence(seqs)
        # Position t of row b is padding exactly when t >= length of sample b.
        padding_mask = th.arange(inputs.size(0)).unsqueeze(0) >= lengths.unsqueeze(1)
        labels = th.LongTensor([label for _, label in samples])

        return inputs.to(out_device), padding_mask.to(out_device), labels.to(out_device)


In [4]:
# Load every cached encoder output for the configured languages into memory
# (MyDataset reads all files of each language directory in its constructor).
dataset = MyDataset(langs)

FileNotFoundError: [Errno 2] No such file or directory: '/mnt/raid0/siqi/analysis/resources/fr'

In [None]:
class BatchSampler(Sampler):
    """Length-sorted batches under a total-length budget, served in random order.

    Samples are sorted by sequence length and greedily packed so that the summed
    length of a batch stays within ``batch_size``. The packing is computed once
    in the constructor; every iteration reshuffles the order of the batches and
    yields each one already collated via ``dataset.collate``.
    """

    def __init__(self, dataset, batch_size):
        """Pre-compute the batches.

        Args:
            dataset: object supporting ``len()``, ``dataset[i][0].size(0)`` (the
                sequence length of sample ``i``) and ``collate(indices)``.
            batch_size: upper bound on the summed sequence length of one batch
                (not a number of samples). A single sequence longer than this
                budget is placed in a batch of its own.
        """
        super().__init__(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.seqlens = [dataset[i][0].size(0) for i in range(len(dataset))]
        self.sorted_indices = sorted(zip(self.seqlens, range(len(dataset))))

        self.all_batch_indices = []
        batch_indices = []
        sum_len = 0
        # Walk the (length, index) pairs in sorted order so that the length added
        # to the running total belongs to the index added to the batch.
        for seqlen, idx in self.sorted_indices:
            # Flush only a non-empty batch, so an over-budget sequence never
            # produces an empty batch that collate() could not handle.
            if batch_indices and sum_len + seqlen > batch_size:
                self.all_batch_indices.append(batch_indices)
                batch_indices, sum_len = [], 0
            batch_indices.append(idx)
            sum_len += seqlen
        if batch_indices:
            self.all_batch_indices.append(batch_indices)

    def __len__(self):
        """Number of batches per pass."""
        return len(self.all_batch_indices)

    def __iter__(self):
        """Yield every batch once, collated, in a freshly shuffled order."""
        # A shallow copy is enough: only the outer list is reordered.
        order = list(self.all_batch_indices)
        random.shuffle(order)
        for batch_indices in order:
            yield self.dataset.collate(batch_indices)

In [None]:
# The second argument is a budget on the summed sequence lengths of one batch,
# not a number of samples.
batch_sampler = BatchSampler(dataset, 10000)

In [None]:
class Classifier(nn.Module):
    """Stack of fairseq ``TransformerEncoderLayer`` modules with a linear output layer.

    Args:
        nlayer: number of encoder layers.
        ndim: feature size of the input; passed as ``encoder_embed_dim``.
        nhid: passed as ``encoder_ffn_embed_dim``.
        nhead: passed as ``encoder_attention_heads``.
        nclass: output size of the final linear layer.
        drop: value used for ``dropout``, ``attention_dropout`` and
            ``activation_dropout``.
    """

    def __init__(self, nlayer, ndim, nhid, nhead, nclass, drop=0.1) -> None:
        super().__init__()

        # One shared settings object is handed to every encoder layer.
        layer_cfg = Namespace(
            encoder_embed_dim=ndim,
            encoder_attention_heads=nhead,
            attention_dropout=drop,
            dropout=drop,
            activation_dropout=drop,
            encoder_normalize_before=True,
            encoder_ffn_embed_dim=nhid,
        )

        stack = []
        for _ in range(nlayer):
            stack.append(TransformerEncoderLayer(layer_cfg))
        self.layers = nn.ModuleList(stack)
        self.linear = nn.Linear(ndim, nclass)

    def forward(self, x, padding_mask):
        """Run the encoder layers, then return the linear output at leading index 0.

        Args:
            x: input features; in this notebook MyDataset.collate supplies them
                as ``(max_len, batch, feat)``.
            padding_mask: mask forwarded unchanged to every encoder layer.

        Returns:
            ``self.linear(hidden)[0]``, i.e. the projection taken at the first
            index of the leading dimension.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, padding_mask)
        projected = self.linear(hidden)
        return projected[0]


In [None]:
# 2 encoder layers over 1024-dim features, 4096-dim feed-forward, 4 attention
# heads, one output class per language.
# The model must live on the same device MyDataset.collate moves the batches to
# (`device`); a separately hard-coded GPU index makes the forward pass fail with
# a device-mismatch error.
classifier = Classifier(nlayer=2, ndim=1024, nhid=4096, nhead=4, nclass=len(langs)).to(device)

In [None]:
# Adam over all classifier parameters with a fixed learning rate of 1e-4.
optimizer = th.optim.Adam(classifier.parameters(), lr=1e-4)

In [None]:
# Cross-entropy between the classifier logits and the integer language ids.
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Number of full passes over batch_sampler made by the training loop.
n_epoch = 20

In [18]:
classifier.train()
# A named, 1-based epoch counter replaces `_`: the value is actually used in the
# progress text, and `_` is IPython's "last output" variable.
for epoch in range(1, n_epoch + 1):
    iterator = tqdm(batch_sampler)
    sum_loss = 0.0
    # `cnt` is the number of batches seen so far in this epoch (1-based), used
    # for the running mean loss shown in the progress bar.
    for cnt, (inputs, padding_mask, labels) in enumerate(iterator, start=1):
        optimizer.zero_grad()
        logits = classifier(inputs, padding_mask)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        iterator.set_description('Epoch {} Loss {:.2f}'.format(epoch, sum_loss / cnt))

Epoch 1 Loss 0.19: 100%|██████████| 464/464 [00:34<00:00, 13.46it/s]
Epoch 2 Loss 0.12: 100%|██████████| 464/464 [00:45<00:00, 10.27it/s]
Epoch 3 Loss 0.06:  28%|██▊       | 129/464 [00:10<00:26, 12.65it/s]


KeyboardInterrupt: 