In [1]:
import torch
import torch.nn.functional as F
import itertools
import numpy as np
import os
import torch_scatter


os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
def oddness(vectors):
    # inputs: vectors [n,d]
    # assumes they come in +/- form, so convert to 0/1 first
    # returns: labels [n], 0 for even, 1 for odd (i.e. XOR)
    return torch.relu(vectors).sum(1) % 2

In [3]:
def xor_noise(n,a,i):
    # returns two tensors, features and labels
    #   features: [n*a^2, a+i]
    #   labels: [n*a^2]
    # the first a components of the features are 'active' and determine the label
    # through XOR (oddness), the next i are drawn at random and are 'inactive' i.e.
    # they do not determine the label
    
    lst = list(itertools.product([0, 1], repeat=a))
    noiseless = 2*torch.Tensor(lst)-1
    labels = oddness(noiseless)
    # repeat n times
    noiseless = noiseless.repeat(n,1)
    labels = labels.repeat(n)

    noise = 2*torch.randint(0,2,(n*2**a,i))-1
    
    return torch.cat((noiseless,noise),1), labels

In [4]:
def partitioned_xor(n,a,i,s,train=True):
    """
    XOR over a elements with i inactive, with the active elements in those
    up to s if train, else after s.
    
    Returns:
     - features: [n*a^2, a+i]
     - labels: [n*a^2]
    
    Generates using xor_noise, where the active elements are the first a,
    and then permutes as appropriate.
    """
    
    if s < a:
        raise ValueError('s needs to be larger than a')
    if s > a+i:
        raise ValueError('s needs to be smaller than a+i')
    if a > i:
        raise ValueError('a needs to be smaller than i')
        
    inputs, labels = xor_noise(n,a,i)
    
    if train:
        # permute the elements up to s
        idx = np.arange(s)
        np.random.shuffle(idx)
        perm = torch.tensor(idx)
        inputs[:,torch.tensor(np.arange(s))]= inputs[:,perm]
    else:
        # flip so that the active elements are last
        reverse = torch.tensor(np.flip(np.arange(a+i)).copy())
        inputs = inputs[:,reverse]
        # permute the elements after s
        idx = np.arange(s,a+i)
        np.random.shuffle(idx)
        perm = torch.tensor(idx)
        inputs[:,torch.tensor(np.arange(s,a+i))] = inputs[:,perm]
        
    return inputs, labels

In [5]:
def hot_attn(Q,K,V,temp):
    return torch.softmax(Q@K.T/temp,-1)@V

In [6]:
def mad(x):
    return (x - x.mean(0)).abs().mean(0)

In [7]:
def rescaled_attn_test(support, support_labels, query, query_labels, temp, iterations, scale):
    
    # standardise the combined set
    standard = F.batch_norm(torch.cat((support,query),0),None,None,training=True)
    # split back up
    support, query = standard[:support.size(0)], standard[support.size(0):]
    
    s0, s1 = support[support_labels==0], support[support_labels==1]
    
    for _ in range(iterations):
        s0 = hot_attn(s0,s0,s0,temp)
        s1 = hot_attn(s1,s1,s1,temp)
        
    combined = torch.cat((s0,s1),0)
    rescale = mad(combined)
    rescale = scale * (rescale - rescale.min()) / (rescale.max() - rescale.min())
    
    predictions = hot_attn(rescale*support,rescale*query,support_labels,1.)
    
    accuracy = ((predictions > 0.5) == query_labels)# .sum()/query_labels.size(0)
    
    return accuracy.cpu().numpy()

In [8]:
def no_scale_attn_test(support, support_labels, query, query_labels, standardise=True):
    
    if standardise:
        # standardise the combined set
        standard = F.batch_norm(torch.cat((support,query),0),None,None,training=True)
        # split back up
        support, query = standard[:support.size(0)], standard[support.size(0):]
    
    predictions = hot_attn(query,support,support_labels,1.)
    
    accuracy = ((predictions > 0.5) == query_labels).sum()/query_labels.size(0)
    
    return accuracy

In [9]:
def feature_permute(support, query):
    """
    Permute the features of the support and query sets the same way.
    """
    
    combined = torch.cat((support,query),0)
    
    idx = np.arange(combined.size(1))
    np.random.shuffle(idx)
    perm = torch.tensor(idx)
    combined[:,torch.tensor(np.arange(combined.size(1)))]= combined[:,perm]
    
    return combined[:support.size(0)], combined[support.size(0):]

In [10]:
support, support_labels = xor_noise(5,3,3)
query, query_labels = xor_noise(5,3,3)
support, query = feature_permute(support, query)

In [11]:
rescaled_attn_test(support, support_labels, query, query_labels, 0.5, 5, 2)

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True])

#### Protonets

In [12]:
class Feedforward(torch.nn.Module):
    def __init__(self, input_size, hidden_size, out_dim):
        super(Feedforward, self).__init__()
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, out_dim)

    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        return output
    
class Averager():
    def __init__(self):
        self.n = 0
        self.v = 0

    def add(self, x):
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v

In [13]:
def euclidean_metric(a, b):
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = -((a - b)**2).sum(dim=2)
    return logits

def classify_proto(train, test, train_labels, **kwargs):
    # proto = torch_scatter.scatter_mean(train, train_labels.type(torch.int64), dim=0)
    tr0, tr1 = train[train_labels==0], train[train_labels==1]
    
    proto_tr0 = tr0.mean(0)
    proto_tr1 = tr1.mean(0)
    proto = torch.stack((proto_tr0, proto_tr1))

    # Compute predictions and accuracy
    logits = euclidean_metric(test, proto)
    # predictions = torch.softmax(logits, axis=-1)
    # return predictions
    return logits

def train(n, a, i, out_dim, max_epoch=1000, verbose = False, val_tasks = 1000):
    seq_len = a + i
    
    # Set up model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Feedforward(input_size=seq_len, hidden_size=100, out_dim=out_dim)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, max_epoch + 1):
        optimizer.zero_grad()
        model.train()

        # Get and reshape data
        support, support_labels = xor_noise(n, a, i)
        query, query_labels = xor_noise(n, a, i)
        support, query = feature_permute(support, query)
        support = support.to(device)
        query = query.to(device)
        support_labels = support_labels.to(device)
        query_labels = query_labels.to(device)

        # Compute prototypes
        support = model(support)
        query = model(query)
        logits = classify_proto(support, query, support_labels)

        # Compute distances and loss
        # one_hot_labels = F.one_hot(query_labels.long(), num_classes=2).float()
        loss = F.cross_entropy(logits, query_labels.long())
        acc = (logits.argmax(-1) == query_labels).sum()/query_labels.size(0)
        loss.backward()
        optimizer.step()
        if verbose:
            print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))
            # print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / config.max_epoch)))
        proto = None; logits = None; loss = None
    
    # Validate
    model.eval()
    val_accs = []
    
    for epoch in range(1, val_tasks):
        # Get and reshape data
        support, support_labels = xor_noise(n, a, i)
        query, query_labels = xor_noise(n, a, i)
        support, query = feature_permute(support, query)
        support = support.to(device)
        query = query.to(device)
        support_labels = support_labels.to(device)
        query_labels = query_labels.to(device)

        # Compute prototypes
        support = model(support)
        query = model(query)
        logits = classify_proto(support, query, support_labels)

        # Compute distances and loss
        # one_hot_labels = F.one_hot(query_labels.long(), num_classes=2).float()
        loss = F.cross_entropy(logits, query_labels.long())
        acc = (logits.argmax(-1) == query_labels).sum()/query_labels.size(0)
        val_accs.append(acc.cpu().numpy())
        if verbose:
            print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))
            # print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / config.max_epoch)))
        proto = None; logits = None; loss = None
    
    return val_accs

In [14]:
n = 5
for L in [5, 10]:
    for out_dim in [1, L, 2*L, L**2]:
        accs = []
        for a in [2, 3, 4]:
            i = L - a

            # Protonet
            val_acc = train(n, a, i, out_dim)
            print('L: {}, a: {}, out dim: {}, acc: {}'.format(L, a, out_dim, np.mean(val_acc)))
            accs.append(val_acc)
        print('${:.1f} \pm {:.1f}$ & ${:.1f} \pm {:.1f}$ & ${:.1f} \pm {:.1f}$'.format(100*np.mean(accs[0]),
                                                                                       100*np.std(accs[0])/np.sqrt(1000),
                                                                                       100*np.mean(accs[1]),
                                                                                       100*np.std(accs[1])/np.sqrt(1000),
                                                                                       100*np.mean(accs[2]),
                                                                                       100*np.std(accs[2])/np.sqrt(1000)))
        print()


L: 5, a: 2, out dim: 1, acc: 0.564965009689331
L: 5, a: 3, out dim: 1, acc: 0.5538288354873657
L: 5, a: 4, out dim: 1, acc: 0.5982983112335205
$56.5 \pm 0.5$ & $55.4 \pm 0.6$ & $59.8 \pm 0.6$

L: 5, a: 2, out dim: 5, acc: 0.7391892671585083
L: 5, a: 3, out dim: 5, acc: 0.6600350141525269
L: 5, a: 4, out dim: 5, acc: 0.9122872948646545
$73.9 \pm 0.5$ & $66.0 \pm 0.7$ & $91.2 \pm 0.6$

L: 5, a: 2, out dim: 10, acc: 0.857357382774353
L: 5, a: 3, out dim: 10, acc: 0.7481982111930847
L: 5, a: 4, out dim: 10, acc: 0.9999499917030334
$85.7 \pm 0.4$ & $74.8 \pm 0.6$ & $100.0 \pm 0.0$

L: 5, a: 2, out dim: 25, acc: 0.9043042659759521
L: 5, a: 3, out dim: 25, acc: 0.804729700088501
L: 5, a: 4, out dim: 25, acc: 1.0
$90.4 \pm 0.3$ & $80.5 \pm 0.6$ & $100.0 \pm 0.0$

L: 10, a: 2, out dim: 1, acc: 0.5189689993858337
L: 10, a: 3, out dim: 1, acc: 0.5014263987541199
L: 10, a: 4, out dim: 1, acc: 0.5024024248123169
$51.9 \pm 0.3$ & $50.1 \pm 0.2$ & $50.2 \pm 0.2$

L: 10, a: 2, out dim: 10, acc: 0.5709

#### Attention model

In [15]:
for L in [5, 10]:
    accs = []
    for a in [2, 3, 4]:
        i = L - a
        
        vas = []
        for epoch in range(1, 1000):
            #
            support, support_labels = xor_noise(n, a, i)
            query, query_labels = xor_noise(n, a, i)
            support, query = feature_permute(support, query)

            # Attn
            acc = rescaled_attn_test(support, support_labels, query, query_labels, 0.5, 5, 2)
            vas.append(acc)
        
        val_acc = vas
        print('L: {}, a: {}, out dim: {}, acc: {}'.format(L, a, L, np.mean(val_acc)))
        accs.append(val_acc)
    print('${:.3f} \pm {:.3f}$ & ${:.3f} \pm {:.3f}$ & ${:.3f} \pm {:.3f}$'.format(100*np.mean(accs[0]),
                                                                                       100*np.std(accs[0])/np.sqrt(1000),
                                                                                       100*np.mean(accs[1]),
                                                                                       100*np.std(accs[1])/np.sqrt(1000),
                                                                                       100*np.mean(accs[2]),
                                                                                       100*np.std(accs[2])/np.sqrt(1000)))
    print()


L: 5, a: 2, out dim: 5, acc: 0.9954954954954955
L: 5, a: 3, out dim: 5, acc: 1.0
L: 5, a: 4, out dim: 5, acc: 1.0
$99.550 \pm 0.212$ & $100.000 \pm 0.000$ & $100.000 \pm 0.000$

L: 10, a: 2, out dim: 10, acc: 0.763913913913914
L: 10, a: 3, out dim: 10, acc: 0.8317317317317318
L: 10, a: 4, out dim: 10, acc: 0.961498998998999
$76.391 \pm 1.343$ & $83.173 \pm 1.183$ & $96.150 \pm 0.608$



#### Protonet on partitioned XOR

In [16]:
def train_partitioned(n, a, i, s, out_dim, max_epoch=1000, verbose = False, val_tasks = 1000):
    seq_len = a + i
    
    # Set up model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Feedforward(input_size=seq_len, hidden_size=100, out_dim=out_dim)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, max_epoch + 1):
        optimizer.zero_grad()
        model.train()

        # Get and reshape data
        support, support_labels = partitioned_xor(n, a, i, s, train=True)
        query, query_labels = partitioned_xor(n, a, i, s, train=True)
        # support, query = feature_permute(support, query)
        support = support.to(device)
        query = query.to(device)
        support_labels = support_labels.to(device)
        query_labels = query_labels.to(device)

        # Compute prototypes
        support = model(support)
        query = model(query)
        logits = classify_proto(support, query, support_labels)

        # Compute distances and loss
        loss = F.cross_entropy(logits, query_labels.long())
        acc = (logits.argmax(-1) == query_labels).sum()/query_labels.size(0)
        loss.backward()
        optimizer.step()
        if verbose:
            print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))
            # print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / config.max_epoch)))
        proto = None; logits = None; loss = None
    
    # Validate
    model.eval()
    val_accs = []
    
    for epoch in range(1, val_tasks):
        # Get and reshape data
        support, support_labels = partitioned_xor(n, a, i, s, train=False)
        query, query_labels = partitioned_xor(n, a, i, s, train=False)
        # support, query = feature_permute(support, query)
        support = support.to(device)
        query = query.to(device)
        support_labels = support_labels.to(device)
        query_labels = query_labels.to(device)

        # Compute prototypes
        support = model(support)
        query = model(query)
        logits = classify_proto(support, query, support_labels)

        # Compute distances and loss
        # one_hot_labels = F.one_hot(query_labels.long(), num_classes=2).float()
        loss = F.cross_entropy(logits, query_labels.long())
        acc = (logits.argmax(-1) == query_labels).sum()/query_labels.size(0)
        val_accs.append(acc.cpu().numpy())
        if verbose:
            print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))
            # print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / config.max_epoch)))
        proto = None; logits = None; loss = None
    
    return val_accs

In [17]:
n = 5
a = 2
i = 3
s = 3
out_dim = (a + i) ** 2
val_acc = train(n, a, i, out_dim)
val_acc_partitioned = train_partitioned(n, a, i, s, out_dim)

In [20]:
s = 4

for L in [10]:
    for out_dim in [1, L, 2*L, L**2]:
        accs = []
        for a in [2, 3, 4]:
            i = L - a

            # Protonet
            val_acc = train_partitioned(n, a, i, s, out_dim)
            print('L: {}, a: {}, out dim: {}, acc: {}'.format(L, a, out_dim, val_acc))
            accs.append(val_acc)
        print('${:.3f}$ & ${:.3f}$ & ${:.3f}$'.format(accs[0], accs[1], accs[2]))
        print()


L: 10, a: 2, out dim: 1, acc: 0.5089098215103149
L: 10, a: 3, out dim: 1, acc: 0.5016274452209473
L: 10, a: 4, out dim: 1, acc: 0.5018521547317505
$0.509$ & $0.502$ & $0.502$

L: 10, a: 2, out dim: 10, acc: 0.5063568949699402
L: 10, a: 3, out dim: 10, acc: 0.49984949827194214
L: 10, a: 4, out dim: 10, acc: 0.5019532442092896
$0.506$ & $0.500$ & $0.502$

L: 10, a: 2, out dim: 20, acc: 0.5112118124961853
L: 10, a: 3, out dim: 20, acc: 0.4973980486392975
L: 10, a: 4, out dim: 20, acc: 0.5030782222747803
$0.511$ & $0.497$ & $0.503$

L: 10, a: 2, out dim: 100, acc: 0.5072084069252014
L: 10, a: 3, out dim: 100, acc: 0.4991496801376343
L: 10, a: 4, out dim: 100, acc: 0.4998376965522766
$0.507$ & $0.499$ & $0.500$

