In [65]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    ExponentialLR,
)

In [66]:
N_SAMPLES_PER_CLASS = 7000
N_CLASSES_TOTAL = 10
TRAIN_SIZE = 0.5 # / NUM_CLASS
N_TRAIN_CLASSES = int(TRAIN_SIZE * N_CLASSES_TOTAL)
N_TEST_CLASSES = int((1 - TRAIN_SIZE) * N_CLASSES_TOTAL)
BATCH_SIZE = 128

N_TRAIN_SUPPORTS = 5
N_TRAIN_QUERIES = 20

N_TEST_SUPPORTS = 5
N_TEST_QUERIES = 20

EMBEDDING_DIM = 128
K = 1
# N_K = 1

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [67]:
# Load CSV files
_train_df = pd.read_csv("/kaggle/input/fashionmnist/fashion-mnist_train.csv")
_test_df = pd.read_csv("/kaggle/input/fashionmnist/fashion-mnist_test.csv")
df = pd.concat([_train_df, _test_df])
df.sort_values('label', inplace=True)

In [68]:
train_df, test_df = train_test_split(df, train_size=TRAIN_SIZE)

In [69]:
class FashionMnistEpisodicDataset(Dataset):
    def __init__(self, dataframe, n_supports=3, n_queries=5, n_episodes=2000,
                 n_total_classes=5, n_samples_per_class=7000):
        
        self.labels = torch.tensor(dataframe.iloc[:, 0].values, dtype=torch.long)  
        self.images = torch.tensor(dataframe.iloc[:, 1:].values, dtype=torch.float32).view(-1, 1, 28, 28) / 255.0

        self.n_supports = n_supports
        self.n_queries = n_queries
        self.n_episodes = n_episodes
        self.n_total_classes = n_total_classes
        self.n_samples_per_class = n_samples_per_class

        # Precompute indices for classes
        self.class_indices = {c: torch.where(self.labels == c)[0] for c in range(n_total_classes)}

    def __len__(self):
        return self.n_episodes 

    def __getitem__(self, _):
        # Sample unique classes for support
        support_classes = np.random.choice(self.n_total_classes, self.n_supports, replace=False)
        query_classes = np.random.choice(support_classes, self.n_queries, replace=True)

        # Efficiently sample indices from each class
        support_indices = torch.cat([self.class_indices[c][torch.randint(0, len(self.class_indices[c]), (1,))] for c in support_classes])
        query_indices = torch.cat([self.class_indices[c][torch.randint(0, len(self.class_indices[c]), (1,))] for c in query_classes])

        # Vectorized indexing
        support_images, support_labels = self.images[support_indices], self.labels[support_indices]
        query_images, query_labels = self.images[query_indices], self.labels[query_indices]

        # onehot_support_labels = F.one_hot(support_labels, num_classes=self.n_total_classes).type(torch.float32)
        # onehot_query_labels = F.one_hot(query_labels, num_classes=self.n_total_classes).type(torch.float32)

        return support_images, support_labels, query_images, query_labels

In [70]:
train_dataset = FashionMnistEpisodicDataset(
    train_df,
    n_supports=N_TRAIN_SUPPORTS,
    n_queries=N_TRAIN_QUERIES,
    n_total_classes=N_TRAIN_CLASSES,
    n_samples_per_class=N_SAMPLES_PER_CLASS,
)

test_dataset = FashionMnistEpisodicDataset(
    test_df,
    n_supports=N_TEST_SUPPORTS,
    n_queries=N_TEST_QUERIES,
    n_total_classes=N_TEST_CLASSES,
    n_samples_per_class=N_SAMPLES_PER_CLASS,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
)

In [71]:
support_images, support_labels, query_images, query_labels = next(iter(train_dataloader))
support_images.shape, support_labels.shape, query_images.shape, query_labels.shape

(torch.Size([128, 5, 1, 28, 28]),
 torch.Size([128, 5]),
 torch.Size([128, 20, 1, 28, 28]),
 torch.Size([128, 20]))

In [72]:
# Expect x of shape Batch N, Eps N, C, W, H
class ConvEmbedding(nn.Module):
    def __init__(self, embedding_dim=32):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1,32,kernel_size=3, bias=False),
            nn.BatchNorm2d(32),
            nn.MaxPool2d((2,2)),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32,64,kernel_size=3, bias=False),
            nn.BatchNorm2d(64),
            nn.MaxPool2d((2,2)),
            nn.ReLU(),
        )
        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fully_connected = nn.Linear(64,embedding_dim)

    def forward(self, x):
        N_B, Eps_Size, C, W, H = x.shape
        x = x.view(N_B * Eps_Size, C, W, H)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.adaptive_avg_pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fully_connected(x)
        x = x.view(N_B, Eps_Size, -1)
        return x

In [73]:
# Appendix: A.2
# _g: embedded supports: N, S, E
# N:  Batch size
# S:  Support size
# E:  Embedding dim
class SupportFullyConditionalEmbedding(nn.Module):
    def __init__(self, embedding_dim=32):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.bi_lstm = nn.LSTM(input_size=embedding_dim,
                               hidden_size=embedding_dim,
                               num_layers=7,
                               bidirectional=True, 
                               batch_first=True)

    def forward(self, _g):
        output, _ = self.bi_lstm(_g)                                     # N, S, 2E
        # Using lstm output because it contains hidden state for every timestep. 
        # TODO: verify this later
        h_fwd, h_bwd = torch.split(output, self.embedding_dim, dim = 2)  # N, S, E
        g = h_fwd + h_bwd + _g                                           # N, S, E
        return g

In [74]:
# Appendix: A.1
# g:  conditionally embedded supports: N, S, E
# _f: embedded queries: N, Q, E
# N:  Batch size
# S:  Support size
# Q:  Query size
# E:  Embedding dim
class QueryFullyConditionalEmbedding(nn.Module):
    def __init__(self, embedding_dim=32, K=10):
        super().__init__()
        self.K = K
        self.embedding_dim=embedding_dim
        self.lstm = nn.LSTMCell(input_size=embedding_dim,
                                hidden_size=embedding_dim * 2)
        self.projection = nn.Linear(embedding_dim*2, embedding_dim)

    def forward(self, g, _f):
        N, S, E = g.shape                                      # N, S, E
        _, Q, _ = _f.shape                                     # N, Q, E
        
        h0 = torch.zeros(N, Q, E, device=DEVICE)               # N, Q, E
        c0 = torch.zeros(N, Q, 2*E, device=DEVICE)             # N, Q, 2E

        hk, ck = h0, c0
        hk_list = []
        
        for k in range(self.K):
            att = F.softmax(torch.bmm(
                hk,                                            # N, Q, E
                g.transpose(1,2)                               # N, E, S
            ), dim=2)                                          # N, Q, S TODO: verify softmax dim
            #                             (N,Q,S,E) * (N,Q,S,1) = N, Q, S, E
            attw = g.unsqueeze(1).expand(-1, Q, -1, -1) * att.unsqueeze(-1)
            
            rk = torch.sum(attw, dim=2)                        # N, Q, E
            hk_rk = torch.concat([hk, rk], dim=2)              # N, Q, 2E
            
            _h, _c = self.lstm(_f.view(N*Q, E), (              # N*Q,  E
                            hk_rk.view(N*Q, 2*E),              # N*Q,  2E
                               ck.view(N*Q, 2*E)))             # N*Q,  2E

            ck = _c.view(N, Q, 2*E)                            # N, Q, 2E
            hk = self.projection(_h).view(N, Q, E) + _f        # N, Q, E
            hk_list.append(hk.unsqueeze(0))                    # 1, N, Q, E

        hk_list = torch.concat(hk_list, dim=0)                 # K, N, Q, E
        return hk_list[-1]                                     # N, Q, E
        

In [75]:
# eq (1)
# g:  conditionally embedded supports: N, S, E
# f:  conditionally embedded queries: N, Q, E
# N:  Batch size
# S:  Support size
# Q:  Query size
# E:  Embedding dim
class AttentionKernel(nn.Module):
    def __init__(self):
        super().__init__()
        pass
        
    def forward(self, f, g):
        norm_f = F.normalize(f, dim=2)
        norm_g = F.normalize(g, dim=2)
        cosine_similarity = torch.bmm(norm_f, norm_g.transpose(1, 2))
        return F.softmax(cosine_similarity, dim=2)

In [76]:
class MatchingNet(nn.Module):
    def __init__(self, C, embedding_dim=32, K=10):
        super().__init__()
        self.conv_embedding = ConvEmbedding(embedding_dim)
        self.support_fce = SupportFullyConditionalEmbedding(embedding_dim)
        self.query_fce = QueryFullyConditionalEmbedding(embedding_dim, K)
        self.a = AttentionKernel()
        self.C = C
    
    def forward(self, supports, queries, support_labels):
        embedded_supports = self.conv_embedding(supports)
        embedded_queries = self.conv_embedding(queries)

        conditioned_embedded_supports = self.support_fce(embedded_supports)
        conditioned_embedded_queries = self.query_fce(conditioned_embedded_supports, embedded_queries)
        
        attw = self.a(conditioned_embedded_queries, conditioned_embedded_supports)
        onehot_support_labels = F.one_hot(support_labels, num_classes=self.C).type(torch.float32)
        return torch.bmm(attw, onehot_support_labels)

In [77]:
class LogLoss(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C

    # pred:   N, Q, C
    # target: N, Q
    def forward(self, pred, target):
        return F.cross_entropy(pred.view(-1, self.C), target.view(-1))

In [78]:
EPOCHS = 500

model = MatchingNet(N_TRAIN_CLASSES, EMBEDDING_DIM, K)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(DEVICE)

train_criterion = LogLoss(N_TRAIN_CLASSES)
test_criterion = LogLoss(N_TEST_CLASSES)
optimizer = optim.AdamW(model.parameters())
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-3
)

In [None]:
def train(model, dataloader, criterion, optimizer, scheduler, device):
    model.train()
    total_loss = 0.0
    total_batches = len(dataloader)
    for support_images, support_labels, query_images, query_labels in dataloader:
        
        support_images = support_images.to(device, non_blocking=True)
        support_labels = support_labels.to(device, non_blocking=True)
        query_images = query_images.to(device, non_blocking=True)
        query_labels = query_labels.to(device, non_blocking=True)
        
        optimizer.zero_grad(set_to_none=True)
        pred = model.forward(support_images, query_images, support_labels)
        loss = criterion(pred, query_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach()
        # if i % 20 == 0:
        #     print(f"Batch: {i}/{total_batches}", end='\r')
            
    scheduler.step()
    torch.cuda.synchronize() 
    return (total_loss / total_batches).item()
    
@torch.no_grad()
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_batches = len(dataloader)
    for support_images, support_labels, query_images, query_labels in dataloader:
        
        support_images = support_images.to(device, non_blocking=True)
        support_labels = support_labels.to(device, non_blocking=True)
        query_images = query_images.to(device, non_blocking=True)
        query_labels = query_labels.to(device, non_blocking=True)
        
        pred = model.forward(support_images, query_images, support_labels)
        loss = criterion(pred, query_labels)
        total_loss += loss.detach()
        
        # if i % 20 == 0:
        # print(f"Batch: {i}/{total_batches}", end='\r')
            
    torch.cuda.synchronize() 
    return (total_loss / total_batches).item()

total_train_losses = []
total_test_losses = []
for i in range(EPOCHS):
    train_loss = train(model, train_dataloader, train_criterion, optimizer, scheduler, DEVICE)
    test_loss = test(model, test_dataloader, test_criterion, DEVICE)
    
    total_train_losses.append(train_loss)
    total_test_losses.append(test_loss)
    
    print(f'Epoch {i:02d}', end='')
    print(f'\t\t Train: {train_loss}', end='')
    print(f'\t Test: {test_loss}')
    # print(f'\t Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

Epoch 00		 Train: 1.5485713481903076	 Test: 1.5365480184555054
Epoch 01		 Train: 1.5086311101913452	 Test: 1.540629267692566
Epoch 02		 Train: 1.4849272966384888	 Test: 1.4928256273269653
Epoch 03		 Train: 1.4697266817092896	 Test: 1.4702166318893433
Epoch 04		 Train: 1.4630653858184814	 Test: 1.4610432386398315
Epoch 05		 Train: 1.4584702253341675	 Test: 1.4629297256469727
Epoch 06		 Train: 1.4565799236297607	 Test: 1.4563027620315552
Epoch 07		 Train: 1.4528082609176636	 Test: 1.4546213150024414
Epoch 08		 Train: 1.4511349201202393	 Test: 1.4531564712524414
Epoch 09		 Train: 1.4482461214065552	 Test: 1.453055500984192
Epoch 10		 Train: 1.4458609819412231	 Test: 1.4500550031661987
Epoch 11		 Train: 1.4408577680587769	 Test: 1.4413583278656006
Epoch 12		 Train: 1.4388175010681152	 Test: 1.4426851272583008
Epoch 13		 Train: 1.436905860900879	 Test: 1.4383854866027832
Epoch 14		 Train: 1.4352446794509888	 Test: 1.4424903392791748
Epoch 15		 Train: 1.432035207748413	 Test: 1.4412919282913

In [None]:
plt.plot(total_train_losses)
plt.plot(total_test_losses)
plt.show()