In [27]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class Encoder(nn.Module):
    def __init__(self, num_classes, num_support_per_class,
                 vocab_size, embed_size, hidden_size,
                 output_dim, weights):
        super(Encoder, self).__init__()
        self.num_support = num_classes * num_support_per_class
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, embed_size)
        if weights is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(weights))

        self.bilstm = nn.LSTM(embed_size, hidden_size, num_layers=1, bidirectional=True, batch_first=True)
        self.fc1 = nn.Linear(2 * hidden_size, output_dim)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def attention(self, x):
        weights = torch.tanh(self.fc1(x))
        weights = self.fc2(weights)  # (batch=k*c, seq_len, d_a)
        batch, seq_len, d_a = weights.shape
        weights = weights.transpose(1, 2)  # (batch=k*c, d_a, seq_len)
        weights = weights.contiguous().view(-1, seq_len)
        weights = F.softmax(weights, dim=1).view(batch, d_a, seq_len)
        sentence_embeddings = torch.bmm(weights, x)  # (batch=k*c, d_a, 2*hidden)
        avg_sentence_embeddings = torch.mean(sentence_embeddings, dim=1)  # (batch, 2*hidden)
        return avg_sentence_embeddings

    def forward(self, x, hidden=None):
        batch_size, _ = x.shape
        if hidden is None:
            h = x.data.new(2, batch_size, self.hidden_size).fill_(0).float()
            c = x.data.new(2, batch_size, self.hidden_size).fill_(0).float()
        else:
            h, c = hidden
        x = self.embedding(x)
        outputs, _ = self.bilstm(x, (h, c))  # (batch=k*c,seq_len,2*hidden)
        outputs = self.attention(outputs)  # (batch=k*c, 2*hidden)
        # (c*s, 2*hidden_size), (c*q, 2*hidden_size)
        support, query = outputs[0: self.num_support], outputs[self.num_support:]
        # print('support, query: {} {}'.format(support.shape, query.shape))
        return support, query


class Induction(nn.Module):
    def __init__(self, C, S, H, iterations):
        super(Induction, self).__init__()
        self.C = C
        self.S = S
        self.H = H
        self.iterations = iterations
        self.W = torch.nn.Parameter(torch.randn(H, H))

    def forward(self, x):
        b_ij = torch.zeros(self.C, self.S).to(x)
        for _ in range(self.iterations):
            d_i = F.softmax(b_ij.unsqueeze(2), dim=1)  # (C,S,1)
            a = x.reshape(-1, self.H)
            b = torch.mm(a, self.W)
            #print('x: {}'.format(x.shape))
            #print(a.shape, b.shape)
            e_ij = torch.mm(x.reshape(-1, self.H), self.W).reshape(self.C, self.S, self.H)  # (C,S,H)
            c_i = torch.sum(d_i * e_ij, dim=1)  # (C,H)
            # squash
            squared = torch.sum(c_i ** 2, dim=1).reshape(self.C, -1)
            coeff = squared / (1 + squared) / torch.sqrt(squared + 1e-9)
            c_i = coeff * c_i
            c_produce_e = torch.bmm(e_ij, c_i.unsqueeze(2))  # (C,S,1)
            b_ij = b_ij + c_produce_e.squeeze(2)

        return c_i


class Relation(nn.Module):
    def __init__(self, C, H, out_size):
        super(Relation, self).__init__()
        self.out_size = out_size
        self.M = torch.nn.Parameter(torch.randn(H, H, out_size))
        self.W = torch.nn.Parameter(torch.randn(C * out_size, C))
        self.b = torch.nn.Parameter(torch.randn(C))

    def forward(self, class_vector, query_encoder):  # (C,H) (Q,H)
        mid_pro = []
        for slice in range(self.out_size):
            slice_inter = torch.mm(torch.mm(class_vector, self.M[:, :, slice]), query_encoder.transpose(1, 0))  # (C,Q)
            mid_pro.append(slice_inter)
        mid_pro = torch.cat(mid_pro, dim=0)  # (C*out_size,Q)
        V = F.relu(mid_pro.transpose(0, 1))  # (Q,C*out_size)
        probs = torch.sigmoid(torch.mm(V, self.W) + self.b)  # (Q,C)
        return probs

In [7]:
import os
import torch
#from model import FewShotInduction
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from glob import glob
from tqdm import tqdm
from torch import optim
from torch.nn.utils.rnn import pad_sequence
#from criterion import Criterion

### 토크나이저 변수 선언

In [8]:
# 반드시 do_lower_case=True로 해야 한다.
# bert-base-uncased는 영어 데이터를 소문자로 변환해서 학습한 모델이기 때문이다.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

### 데이터셋과 데이터로더 만들기

In [10]:
class AmazonDataset():
    def __init__(self, data_path, tokenizer, dtype):
        self.data_path = data_path
        self.tokenizer = tokenizer
        with open(f'{dtype}.list', 'r') as f:
            self.categories = [oneline.rstrip() for oneline in f]
        self.support_dataset = {}
        self.dataset = {}
        for category in tqdm(self.categories, desc='reading categories'):
            self.dataset[category] = {
                'neg': self.get_data(category, 'neg', dtype),
                'pos': self.get_data(category, 'pos', dtype)
            }
        
        if dtype == 'test' or dtype == 'dev':
            for category in tqdm(self.categories, desc='reading categories for support'):
                self.support_dataset[category] = {
                    'neg': self.get_data(category, 'neg', 'train'),
                    'pos': self.get_data(category, 'pos', 'train'),
                }
        
    def read_files(self, category, label, dtype):
        data = {
            'text': [],
            'label': []
        }
        for t in ['t2', 't4', 't5']:
            filename = f'{category}.{t}.{dtype}'
            with open(os.path.join(self.data_path, filename), 'r', encoding='utf-8') as f:
                for oneline in f:
                    oneline = oneline.rstrip()
                    text = oneline[:-2]
                    if int(oneline[-2:]) == 1 and label == 'pos':
                        tensor = self.tokenizer(text, return_tensors='pt')
                        data['text'].append(tensor['input_ids'][0])
                        data['label'].append(1)
                    elif int(oneline[-2:]) == -1 and label == 'neg':
                        tensor = self.tokenizer(text, return_tensors='pt')
                        data['text'].append(tensor['input_ids'][0])
                        data['label'].append(0)
        data['label'] = torch.tensor(data['label'])
        return data
    
    def get_data(self, category, label, dtype):
        data = self.read_files(category, label, dtype)
        return data

In [11]:
data_path = 'Amazon_few_shot'

In [13]:
train_dataset = AmazonDataset(data_path, tokenizer, 'train')
dev_dataset = AmazonDataset(data_path, tokenizer, 'dev')
test_dataset = AmazonDataset(data_path, tokenizer, 'test')

reading categories:   0%|                                                                       | 0/14 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (625 > 512). Running this sequence through the model will result in indexing errors
reading categories: 100%|██████████████████████████████████████████████████████████████| 14/14 [02:41<00:00, 11.55s/it]
reading categories: 100%|████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.05it/s]
reading categories for support: 100%|████████████████████████████████████████████████████| 5/5 [00:07<00:00,  1.54s/it]
reading categories: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00,  4.56s/it]
reading categories for support: 100%|████████████████████████████████████████████████████| 4/4 [00:00<00:00, 14.65it/s]


In [14]:
def pad_text(a_text, b_text):
    a_text_len = a_text.shape[1]
    b_text_len = b_text.shape[1]

    if a_text_len > b_text_len:
        b_text = torch.cat([b_text, torch.zeros(b_text.shape[0], a_text_len-b_text_len).long()], dim=1)
    else:
        a_text = torch.cat([a_text, torch.zeros(a_text.shape[0], b_text_len-a_text_len).long()], dim=1)
        
    return a_text, b_text

In [15]:
class AmazonDataLoader():
    def __init__(self, dataset, batch_size, n_support):
        assert n_support % 2 == 0, 'n_support should be multiple of 2'
        self.dataset = dataset
        self.batch_size = batch_size
        self.n_support = n_support
        self.neg_idx = {k:0 for k in dataset.dataset}
        self.pos_idx = {k:0 for k in dataset.dataset}
        self.neg_len = {k:len(dataset.dataset[k]['neg']['text']) for k in dataset.dataset}
        self.pos_len = {k:len(dataset.dataset[k]['pos']['text']) for k in dataset.dataset}
        self.neg = {k:dataset.dataset[k]['neg'] for k in dataset.dataset}
        self.pos = {k:dataset.dataset[k]['pos'] for k in dataset.dataset}
        self.idx = 0
        self.categories = [k for k in dataset.dataset]
        
        # prepare for test dataset, support dataset should come from "*.train"
        self.neg_support_idx = {}
        self.pos_support_idx = {}
        self.neg_support_len = {}
        self.pos_support_len = {}
        if self.dataset.support_dataset:
            self.neg_support_idx = {k:0 for k in self.dataset.support_dataset}
            self.pos_support_idx = {k:0 for k in self.dataset.support_dataset}
            self.neg_support_len = {k:len(self.dataset.support_dataset[k]['neg']['text']) for k in self.dataset.support_dataset}
            self.pos_support_len = {k:len(self.dataset.support_dataset[k]['pos']['text']) for k in self.dataset.support_dataset}
        
    def get_batch(self):
        category = self.categories[self.idx % len(self.categories)]
        neg = self.neg[category]
        pos = self.pos[category]
        neg_start_idx = self.neg_idx[category] % self.neg_len[category]
        pos_start_idx = self.pos_idx[category] % self.pos_len[category]
        
        # prepare negative/positive dataset
        neg_text = neg['text'][neg_start_idx:neg_start_idx+(self.batch_size//2)]
        pos_text = pos['text'][pos_start_idx:pos_start_idx+(self.batch_size//2)]
        neg_label = neg['label'][neg_start_idx:neg_start_idx+(self.batch_size//2)]
        pos_label = pos['label'][pos_start_idx:pos_start_idx+(self.batch_size//2)]
        self.neg_idx[category] += (self.batch_size//2)
        self.pos_idx[category] += (self.batch_size//2)
        
        if len(neg_text) + len(pos_text) != self.batch_size:
            return self.get_batch()
            
        # padding text dataset
        neg_text = pad_sequence([n for n in neg_text], batch_first=True)
        pos_text = pad_sequence([p for p in pos_text], batch_first=True)
        neg_text, pos_text = pad_text(neg_text, pos_text)
            
        # prepare support/query text
        neg_support_text = neg_text[:self.n_support//2]
        pos_support_text = pos_text[:self.n_support//2]
        neg_query_text = neg_text[self.n_support//2:]
        pos_query_text = pos_text[self.n_support//2:]
        
        # prepare support/query label
        neg_support_label = neg_label[:self.n_support//2]
        pos_support_label = pos_label[:self.n_support//2]
        neg_query_label = neg_label[self.n_support//2:]
        pos_query_label = pos_label[self.n_support//2:]
        
        # merge support/query text
        support_text = torch.cat([neg_support_text, pos_support_text], dim=0)
        query_text = torch.cat([neg_query_text, pos_query_text], dim=0)
        
        # merge support/query label
        support_label = torch.cat([neg_support_label, pos_support_label], dim=0)
        query_label = torch.cat([neg_query_label, pos_query_label], dim=0)
        
        # make data and label
        data = torch.cat([support_text, query_text], dim=0)
        label = torch.cat([support_label, query_label], dim=0)
        
        # increase category index
        self.idx += 1
        return data, label
    
    def get_batch_test(self):
        assert self.dataset.support_dataset, 'support_dataset is empty'
        
        category = self.categories[self.idx % len(self.categories)]
        neg = self.neg[category]
        pos = self.pos[category]
        neg_query_start_idx = self.neg_idx[category] % self.neg_len[category]
        pos_query_start_idx = self.pos_idx[category] % self.pos_len[category]
        neg_support_start_idx = self.neg_support_idx[category] % self.neg_support_len[category]
        pos_support_start_idx = self.pos_support_idx[category] % self.pos_support_len[category]
        
        # prepare negative/positive support dataset from support_dataset
        category_suuport_dataset = self.dataset.support_dataset[category]
        neg_support_text = category_suuport_dataset['neg']['text'][neg_support_start_idx:neg_support_start_idx+self.n_support//2]
        pos_support_text = category_suuport_dataset['pos']['text'][pos_support_start_idx:pos_support_start_idx+self.n_support//2]
        neg_support_label = category_suuport_dataset['neg']['label'][neg_support_start_idx:neg_support_start_idx+self.n_support//2]
        pos_support_label = category_suuport_dataset['pos']['label'][pos_support_start_idx:pos_support_start_idx+self.n_support//2]
        self.neg_support_idx[category] += (self.n_support//2)
        self.pos_support_idx[category] += (self.n_support//2)
        
        # prepare negative/positive query dataset
        neg_query_text = neg['text'][neg_query_start_idx:neg_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        pos_query_text = pos['text'][pos_query_start_idx:pos_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        neg_query_label = neg['label'][neg_query_start_idx:neg_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        pos_query_label = pos['label'][pos_query_start_idx:pos_query_start_idx+(self.batch_size//2 - self.n_support//2)]
        self.neg_idx[category] += (self.batch_size//2 - self.n_support//2)
        self.pos_idx[category] += (self.batch_size//2 - self.n_support//2)
        
        # padding support text dataset
        if self.n_support:
            neg_support_text = pad_sequence([n for n in neg_support_text], batch_first=True)
            pos_support_text = pad_sequence([n for n in pos_support_text], batch_first=True)
            neg_support_text, pos_support_text = pad_text(neg_support_text, pos_support_text)
        else:
            neg_support_text = torch.tensor([[]])
            pos_support_text = torch.tensor([[]])
            
        # padding text dataset
        neg_query_text = pad_sequence([n for n in neg_query_text], batch_first=True)
        pos_query_text = pad_sequence([p for p in pos_query_text], batch_first=True)
        neg_query_text, pos_query_text = pad_text(neg_query_text, pos_query_text)

        # concatenating support/query text dataset
        support_text = torch.cat([neg_support_text, pos_support_text], dim=0)
        query_text = torch.cat([neg_query_text, pos_query_text], dim=0)
        support_text, query_text = pad_text(support_text, query_text)

        # make final data and label
        if self.n_support:
            data = torch.cat([support_text, query_text], dim=0)
        else:
            data = query_text
        label = torch.cat([neg_support_label, pos_support_label, neg_query_label, pos_query_label], dim=0)
        return data, label

In [16]:
support = 5

In [17]:
train_dataloader = AmazonDataLoader(train_dataset, batch_size=64, n_support=support*2)
dev_dataloader = AmazonDataLoader(dev_dataset, batch_size=64, n_support=support*2)
test_dataloader = AmazonDataLoader(test_dataset, batch_size=64, n_support=support*2)

In [18]:
for i in range(10):
    d, l = train_dataloader.get_batch()
    print(d.shape, l.float().mean())

torch.Size([64, 149]) tensor(0.5000)
torch.Size([64, 460]) tensor(0.5000)
torch.Size([64, 254]) tensor(0.5000)
torch.Size([64, 262]) tensor(0.5000)
torch.Size([64, 1283]) tensor(0.5000)
torch.Size([64, 1658]) tensor(0.5000)
torch.Size([64, 613]) tensor(0.5000)
torch.Size([64, 359]) tensor(0.5000)
torch.Size([64, 530]) tensor(0.5000)
torch.Size([64, 602]) tensor(0.5000)


In [19]:
for i in range(10):
    d, l = dev_dataloader.get_batch_test()
    print(d.shape, l.float().mean())

torch.Size([64, 327]) tensor(0.5000)
torch.Size([55, 181]) tensor(0.5818)
torch.Size([64, 198]) tensor(0.5000)
torch.Size([46, 295]) tensor(0.6957)
torch.Size([64, 197]) tensor(0.5000)
torch.Size([64, 276]) tensor(0.5000)
torch.Size([55, 186]) tensor(0.5818)
torch.Size([64, 270]) tensor(0.5000)
torch.Size([26, 130]) tensor(0.4615)
torch.Size([64, 327]) tensor(0.5000)


In [20]:
for i in range(10):
    d, l = test_dataloader.get_batch_test()
    print(d.shape, l.float().mean())

torch.Size([64, 743]) tensor(0.5000)
torch.Size([64, 841]) tensor(0.5000)
torch.Size([64, 1386]) tensor(0.5000)
torch.Size([64, 706]) tensor(0.5000)
torch.Size([64, 1026]) tensor(0.5000)
torch.Size([64, 1126]) tensor(0.5000)
torch.Size([64, 1116]) tensor(0.5000)
torch.Size([64, 1333]) tensor(0.5000)
torch.Size([64, 568]) tensor(0.5000)
torch.Size([64, 570]) tensor(0.5000)


### 아마존 데이터셋에 대한 메타러닝 모델 클래스 정의하기

In [28]:
class FewShotInduction(nn.Module):
    def __init__(self, C, S, vocab_size, embed_size, hidden_size, d_a,
                 iterations, outsize, weights=None):
        super(FewShotInduction, self).__init__()
        self.encoder = Encoder(C, S, vocab_size, embed_size, hidden_size, d_a, weights)
        self.induction = Induction(C, S, 2 * hidden_size, iterations)
        self.relation = Relation(C, 2 * hidden_size, outsize)

    def forward(self, x):
        support_encoder, query_encoder = self.encoder(x)  # (k*c, 2*hidden_size)
        # print('** support_encoder: {}'.format(support_encoder.shape))
        # print('** query_encoder: {}'.format(query_encoder.shape))
        class_vector = self.induction(support_encoder)
        # print('** class_vector: {}'.format(class_vector.shape))
        probs = self.relation(class_vector, query_encoder)
        # print('** relation: {}'.format(probs.shape))
        return probs

### 아마존 데이터셋에 대한 메타러닝 모델 객체 정의하기

In [21]:
model = FewShotInduction(C=2,
                         S=support,
                         vocab_size=len(tokenizer),
                         embed_size=300,
                         hidden_size=128,
                         d_a=64,
                         iterations=3,
                         outsize=100)
if torch.cuda.is_available():
    model = model.cuda()

In [22]:
len(tokenizer)

30522

In [23]:
optimizer = optim.Adam(model.parameters(), lr=float(1e-4))

In [29]:
import torch
from torch.nn.modules.loss import _Loss


class Criterion(_Loss):
    def __init__(self, way=2, shot=5):
        super(Criterion, self).__init__()
        self.amount = way * shot

    def forward(self, probs, target, return_pred_label=False):  # (Q,C) (Q)
        target = target[self.amount:]
        target_onehot = torch.zeros_like(probs)
        #print('** sum of probs/target_onehot: {} {}'.format(probs.sum(), target_onehot.sum()))
        target_onehot = target_onehot.scatter(1, target.reshape(-1, 1), 1)
        loss = torch.mean((probs - target_onehot) ** 2)
        pred = torch.argmax(probs, dim=1)
        acc = torch.sum(target == pred).float() / target.shape[0]
        #print('** acc: {}'.format(acc))

        if return_pred_label:
            return loss, acc, pred, target
        else:
            return loss, acc

In [24]:
criterion = Criterion(way=2, shot=support)

In [36]:
def train(episode):
    model.train()
    data, target = train_dataloader.get_batch()
    if torch.cuda.is_available():
        data = data.cuda()
        target = target.cuda()

    optimizer.zero_grad()
    predict = model(data)
    loss, acc = criterion(predict, target)
    loss.backward()
    optimizer.step()
    return loss

In [37]:
def dev(episode):
    model.eval()
    correct = 0.
    count = 0.
    for i in range(100):
        data, target = dev_dataloader.get_batch_test()
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
            
        predict = model(data)
        _, acc = criterion(predict, target)
        amount = len(target) - support * 2
        correct += acc * amount
        count += amount
    acc = correct / count
    return acc

In [30]:
dev_interval = 100
best_acc = -1.0

In [38]:
tbar = tqdm(range(1, 10000))
for episode in tbar:
    
    loss = train(episode)
    if episode % dev_interval == 0:
        acc = dev(episode)
        if acc > best_acc:
            print('Better acc! Saving model! -> {:.4f}'.format(acc))
            best_acc = acc
    tbar.set_postfix(loss=loss)  

  0%|                                | 3/9999 [00:17<15:47:47,  5.69s/it, loss=tensor(0.4907, grad_fn=<MeanBackward0>)]


KeyboardInterrupt: 

In [39]:
torch.save(model.state_dict(), f'fewshot_model_{support}.bin')

### 학습된 메타러닝 모델 평가하기

In [40]:
support = 5
criterion = Criterion(way=2, shot=support)

In [41]:
model2 = FewShotInduction(C=2,
                         S=support,
                         vocab_size=len(tokenizer),
                         embed_size=300,
                         hidden_size=128,
                         d_a=64,
                         iterations=3,
                         outsize=100)

In [43]:
model2.load_state_dict(torch.load(f'./fewshot_model_{support}.bin', map_location='cpu'))

<All keys matched successfully>

In [44]:
def test():
    model.eval()
    correct = 0.
    count = 0.
    for i in range(100):
        data, target = test_dataloader.get_batch_test()
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
            
        predict = model(data)
        _, acc = criterion(predict, target)
        amount = len(target) - support * 2
        correct += acc * amount
        count += amount
        
    acc = correct / count
    print('Test Acc: {}'.format(acc))
    return acc

In [45]:
test()

KeyboardInterrupt: 