In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
# --- Dataset layout ---------------------------------------------------------
# Expected number of rows per class in the combined train+test frame.
N_SAMPLES_PER_CLASS = 7000
N_CLASSES_TOTAL = 10
# Fraction of the classes (not of the samples) reserved for training.
TRAIN_SIZE = 0.5
N_TRAIN_CLASSES = int(TRAIN_SIZE * N_CLASSES_TOTAL)
# Derived by subtraction so train + test always add up to N_CLASSES_TOTAL.
# int((1 - TRAIN_SIZE) * N_CLASSES_TOTAL) truncates float round-off: with
# TRAIN_SIZE = 0.9 it evaluates to 0 instead of 1.
N_TEST_CLASSES = N_CLASSES_TOTAL - N_TRAIN_CLASSES

# --- Episode sampling -------------------------------------------------------
BATCH_SIZE = 32   # episodes per batch
N_SUPPORTS = 3    # support images per episode (one per distinct class)
N_QUERIES = 5     # query images per episode

# --- Model ------------------------------------------------------------------
EMBEDDING_DIM = 32

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Load both Fashion-MNIST CSV splits and merge them into a single frame ordered
# by class label. The class-wise train/test split is made later, so the
# original per-file split is deliberately discarded here.
_train_df = pd.read_csv("/kaggle/input/fashionmnist/fashion-mnist_train.csv")
_test_df = pd.read_csv("/kaggle/input/fashionmnist/fashion-mnist_test.csv")
df = (
    # ignore_index: both CSVs are 0-based, so a plain concat would leave
    # duplicate index labels in the merged frame.
    pd.concat([_train_df, _test_df], ignore_index=True)
    # Stable sort keeps the within-class row order reproducible; no in-place
    # mutation, so re-running the cell always rebuilds the same frame.
    .sort_values("label", kind="stable")
    .reset_index(drop=True)
)

In [4]:
# Split by class: the lowest N_TRAIN_CLASSES label values form the training
# set and the remaining labels form the test set. Selecting on the label value
# (rather than slicing the first N_TRAIN_CLASSES * N_SAMPLES_PER_CLASS rows)
# does not depend on every class having exactly N_SAMPLES_PER_CLASS rows or on
# the frame being sorted.
# Assumes labels are the integers 0 .. N_CLASSES_TOTAL - 1.
_is_train_class = (df["label"] < N_TRAIN_CLASSES).to_numpy()
train_df = df[_is_train_class]
test_df = df[~_is_train_class]

assert train_df["label"].nunique() == N_TRAIN_CLASSES, "unexpected number of training classes"

In [5]:
class FashionMnistEpisodicDataset(Dataset):
    """Randomly sampled few-shot episodes over a Fashion-MNIST style dataframe.

    Every item is one episode: ``n_supports`` images taken from distinct
    classes (one image per class) plus ``n_queries`` images whose classes are
    drawn, with replacement, from those support classes.

    Args:
        dataframe: Frame whose first column is the integer class label and
            whose remaining 784 columns are pixel values (reshaped to
            1x28x28 and divided by 255).
        n_supports: Number of distinct classes, and support images, per episode.
        n_queries: Number of query images per episode.
        n_episodes: Value reported by ``len()``; episodes are generated on the
            fly, so this only fixes the length of one pass over the dataset.
        n_total_classes: Upper bound on the number of classes to sample from.
            The classes used are the first ``n_total_classes`` distinct labels
            (ascending) that actually occur in ``dataframe``, so a frame that
            holds e.g. labels 5..9 works as well as one holding 0..4.
        n_samples_per_class: Stored for reference only; not used for sampling.
        is_train: Stored for reference only; not used for sampling.

    Raises:
        ValueError: If ``n_supports`` is smaller than 1 or larger than the
            number of usable classes.
    """

    def __init__(self, dataframe, n_supports=3, n_queries=5, n_episodes=1000,
                 n_total_classes=5, n_samples_per_class=7000, is_train=False):

        self.labels = torch.tensor(dataframe.iloc[:, 0].values, dtype=torch.long)
        self.images = torch.tensor(dataframe.iloc[:, 1:].values, dtype=torch.float32).view(-1, 1, 28, 28) / 255.0

        self.n_supports = n_supports
        self.n_queries = n_queries
        self.n_episodes = n_episodes
        self.n_total_classes = n_total_classes
        self.n_samples_per_class = n_samples_per_class
        self.is_train = is_train

        # Use the labels that are really present instead of assuming they are
        # 0 .. n_total_classes - 1; torch.unique returns them sorted ascending.
        self.classes = torch.unique(self.labels).tolist()[:n_total_classes]
        if not 1 <= n_supports <= len(self.classes):
            raise ValueError(
                f"n_supports={n_supports} must be between 1 and the number of "
                f"available classes ({len(self.classes)})"
            )

        # Precompute, per class, the row positions of its samples.
        self.class_indices = {c: torch.where(self.labels == c)[0] for c in self.classes}

    def __len__(self):
        return self.n_episodes

    def _sample_index(self, label):
        """Return a 1-element tensor holding a random row position of ``label``."""
        pool = self.class_indices[label]
        return pool[torch.randint(0, len(pool), (1,))]

    def __getitem__(self, _):
        # The requested index is ignored: every call draws a fresh episode.
        # All randomness comes from the torch RNG, which DataLoader seeds
        # differently in every worker process.
        support_slots = torch.randperm(len(self.classes))[:self.n_supports]
        query_slots = support_slots[torch.randint(0, self.n_supports, (self.n_queries,))]

        support_classes = [self.classes[i] for i in support_slots.tolist()]
        query_classes = [self.classes[i] for i in query_slots.tolist()]

        # NOTE: support and query samples of a class are drawn independently,
        # so a query can occasionally be the very same image as its support.
        support_indices = torch.cat([self._sample_index(c) for c in support_classes])
        query_indices = torch.cat([self._sample_index(c) for c in query_classes])

        support_images, support_labels = self.images[support_indices], self.labels[support_indices]
        query_images, query_labels = self.images[query_indices], self.labels[query_indices]

        return support_images, support_labels, query_images, query_labels

In [6]:
# Episodic training set restricted to the training classes.
train_dataset = FashionMnistEpisodicDataset(
    train_df,
    n_supports=N_SUPPORTS,
    n_queries=N_QUERIES,
    n_total_classes=N_TRAIN_CLASSES,
    n_samples_per_class=N_SAMPLES_PER_CLASS,
    is_train=True,  # this is the training split; the default is False
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    # The dataset ignores the requested index and samples each episode at
    # random, so shuffling the indices would add nothing.
    shuffle=False,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA just triggers a warning.
    pin_memory=DEVICE.type == "cuda",
)

In [8]:
class ConvEmbedding(nn.Module):
    """Convolutional encoder that embeds every image of every episode.

    Two blocks of (3x3 convolution -> batch norm -> 2x2 max-pool -> ReLU) are
    followed by global average pooling and a linear layer.

    Args:
        embedding_dim: Size of the embedding produced for each image.

    Input:
        x: Tensor of shape (batch, episode_size, channels, height, width); the
           first convolution expects ``channels == 1``.

    Returns:
        Tensor of shape (batch, episode_size, embedding_dim).
    """

    def __init__(self, embedding_dim=32):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, bias=False),  # bias is redundant before batch norm
            nn.BatchNorm2d(16),
            nn.MaxPool2d((2, 2)),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, bias=False),
            nn.BatchNorm2d(32),
            nn.MaxPool2d((2, 2)),
            nn.ReLU(),
        )
        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fully_connected = nn.Linear(32, embedding_dim)

    def forward(self, x):
        n_batch, episode_size, channels, height, width = x.shape
        # Fold the episode axis into the batch axis so the 2-D convolutions see
        # a plain image batch. reshape (unlike view) also accepts
        # non-contiguous inputs such as transposed or sliced tensors.
        x = x.reshape(n_batch * episode_size, channels, height, width)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.adaptive_avg_pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fully_connected(x)
        # Restore the (batch, episode) layout.
        return x.reshape(n_batch, episode_size, -1)

In [235]:
# Appendix: A.2
# Shape legend -- N: batch size, S: support size, E: embedding dim
class SupportFullyConditionalEmbedding(nn.Module):
    """Embed each support sample in the context of the whole support set.

    A bidirectional LSTM reads the support embeddings as a sequence; every
    sample's result is its forward hidden state plus its backward hidden state
    plus a skip connection to its own input embedding.

    Args:
        embedding_dim: Size E of the incoming embeddings (also the LSTM hidden
            size per direction).
    """

    def __init__(self, embedding_dim=32):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.bi_lstm = nn.LSTM(embedding_dim, embedding_dim,
                               batch_first=True,
                               bidirectional=True)

    def forward(self, _g):
        """Map raw support embeddings ``_g`` (N, S, E) to conditioned ones (N, S, E)."""
        # First element of the LSTM result: the hidden state of every time
        # step, forward and backward directions concatenated on the last axis.
        hidden_states = self.bi_lstm(_g)[0]                  # N, S, 2E
        width = self.embedding_dim
        forward_states = hidden_states[..., :width]          # N, S, E
        backward_states = hidden_states[..., width:]         # N, S, E
        return forward_states + backward_states + _g         # N, S, E

In [241]:
# Appendix: A.1
# Shape legend -- N: batch size, S: support size, Q: query size, E: embedding dim
class QueryFullyConditionalEmbedding(nn.Module):
    """Embed each query by attending over the conditioned support set.

    For ``K`` processing steps an LSTM cell is fed the raw query embedding
    while its hidden state is the previous query state concatenated with an
    attention read-out of the support embeddings. The cell's 2E-wide hidden
    output is projected back to E and added to the raw query embedding (skip
    connection).

    Args:
        embedding_dim: Size E of the support and query embeddings.
        K: Number of attention/LSTM processing steps (at least 1).

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(self, embedding_dim=32, K=3):
        super().__init__()
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        self.K = K
        self.embedding_dim = embedding_dim
        self.lstm = nn.LSTMCell(input_size=embedding_dim,
                                hidden_size=embedding_dim * 2)
        self.projection = nn.Linear(embedding_dim * 2, embedding_dim)

    def forward(self, g, _f):
        """Condition the query embeddings on the support set.

        Args:
            g: Conditioned support embeddings, shape (N, S, E).
            _f: Raw query embeddings, shape (N, Q, E).

        Returns:
            Conditioned query embeddings after the last step, shape (N, Q, E).
        """
        N, S, E = g.shape
        Q = _f.shape[1]

        flat_queries = _f.reshape(N * Q, E)                    # N*Q, E
        support_keys = g.transpose(1, 2)                       # N, E, S

        # Initial states follow the input's device and dtype, so the module
        # also works on GPU and in non-default precision.
        hk = _f.new_zeros(N, Q, E)                             # N, Q, E
        ck = _f.new_zeros(N * Q, 2 * E)                        # N*Q, 2E

        for _ in range(self.K):
            # Attention is a distribution over the supports, so it must be
            # normalised along the S axis (the last one), not across queries.
            att = F.softmax(torch.bmm(hk, support_keys), dim=-1)   # N, Q, S
            # Attention-weighted sum of the support embeddings.
            rk = torch.bmm(att, g)                             # N, Q, E
            hk_rk = torch.cat([hk, rk], dim=2)                 # N, Q, 2E

            _h, ck = self.lstm(flat_queries,
                               (hk_rk.reshape(N * Q, 2 * E), ck))  # N*Q, 2E each

            hk = self.projection(_h).reshape(N, Q, E) + _f     # N, Q, E

        return hk
        

In [242]:
# Pull one batch of episodes and display the shape of each of its four tensors.
(support_images,
 support_labels,
 query_images,
 query_labels) = next(iter(train_dataloader))
tuple(t.shape for t in (support_images, support_labels, query_images, query_labels))

(torch.Size([32, 3, 1, 28, 28]),
 torch.Size([32, 3]),
 torch.Size([32, 5, 1, 28, 28]),
 torch.Size([32, 5]))

In [243]:
# Separate encoders for supports and queries. They are kept as named modules so
# their parameters stay reachable (e.g. for an optimizer) instead of being
# discarded right after the call, and they are sized from the shared config
# constant rather than the constructor default.
support_encoder = ConvEmbedding(embedding_dim=EMBEDDING_DIM)
query_encoder = ConvEmbedding(embedding_dim=EMBEDDING_DIM)

embedded_supports = support_encoder(support_images)
embedded_queries = query_encoder(query_images)

In [244]:
# Sanity check: one embedding vector per support image of every episode.
embedded_supports.size()

torch.Size([32, 3, 32])

In [245]:
# Keep a handle on the module (so its LSTM weights can be trained later) and
# size it from the shared config constant rather than the constructor default.
support_conditioner = SupportFullyConditionalEmbedding(embedding_dim=EMBEDDING_DIM)
conditioned_embedded_supports = support_conditioner(embedded_supports)
conditioned_embedded_supports.shape

OUTPUT:  torch.Size([32, 3, 64])


torch.Size([32, 3, 32])

In [246]:
# Keep a handle on the module (so its weights can be trained later) and size it
# from the shared config constant rather than the constructor default.
query_conditioner = QueryFullyConditionalEmbedding(embedding_dim=EMBEDDING_DIM)
conditioned_embedded_queries = query_conditioner(conditioned_embedded_supports, embedded_queries)
conditioned_embedded_queries.shape

torch.Size([32, 5, 32])