In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import tqdm
import torchvision.datasets as datasets
from torch.utils.data import Sampler
import random

In [2]:
# Root folder under which torchvision places the `flowers-102/` download.
data_directory = "./"

In [3]:
# Fetch (or verify) the Flowers102 archives once; the displayed repr is the default train split.
torchvision.datasets.Flowers102(root=str(data_directory), download=True)

Downloading https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz to flowers-102/102flowers.tgz


  0%|          | 0/344862509 [00:00<?, ?it/s]

Extracting flowers-102/102flowers.tgz to flowers-102
Downloading https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat to flowers-102/imagelabels.mat


  0%|          | 0/502 [00:00<?, ?it/s]

Downloading https://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat to flowers-102/setid.mat


  0%|          | 0/14989 [00:00<?, ?it/s]

Dataset Flowers102
    Number of datapoints: 1020
    Root location: ./
    split=train

In [4]:
image_size = 224  # side length (pixels) of the square images fed to the backbone

In [5]:
# Per-image preprocessing: convert to a float tensor, resize to a fixed square
# (aspect ratio is not preserved), then standardise each channel.
# NOTE(review): these mean/std values are not the ImageNet statistics a pretrained
# ResNet-18 is usually paired with — confirm they were computed for this data.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(image_size, image_size)),
    transforms.Normalize(
        mean=[0.507, 0.487, 0.441],
        std=[0.267, 0.256, 0.276],
    ),
])

In [6]:
# Load the three official splits (train, val, test — in that order) with the shared preprocessing.
training_set, validation_set, test_set = (
    torchvision.datasets.Flowers102(root=str(data_directory), split=split_name, transform=transform)
    for split_name in ('train', 'val', 'test')
)

# TaskSampler requires a get_labels() method; expose a fresh copy of each
# dataset's private `_labels` list on every call.
validation_set.get_labels = lambda: list(validation_set._labels)
test_set.get_labels = lambda: list(test_set._labels)

In [7]:
# Episode (task) geometry for few-shot evaluation: each task draws n_way classes,
# with k_shot support images and n_query query images per class.
n_way = 5
k_shot = 3
n_query = 5
# Number of passes over the training split during supervised pre-training.
n_epochs = 20
# Number of random episodes drawn per evaluation pass.
n_val_tasks = 100
n_test_tasks = 1000
# Mini-batch size for the supervised training loader.
batch_size = 64

# Seed Python's `random` (used by TaskSampler) and torch (DataLoader shuffling,
# weight initialisation) so that Restart & Run All reproduces the same episodes.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

In [8]:
class TaskSampler(Sampler):
    """Batch sampler that yields few-shot episodes (tasks) from a labelled dataset.

    Each yielded batch is a flat list of dataset indices: for each of ``n_way``
    randomly chosen labels, ``k_shot + n_query`` distinct random indices carrying
    that label, with each label's indices contiguous. Use it as
    ``DataLoader(batch_sampler=...)`` together with ``collate_fn`` to split each
    batch into support and query sets.

    The dataset must provide ``get_labels()`` returning one hashable label per item,
    in index order. Every label drawn must have at least ``k_shot + n_query`` items,
    otherwise ``random.sample`` raises ``ValueError`` during iteration.
    """

    def __init__(self, dataset, n_way, k_shot, n_query, n_tasks):
        super().__init__(data_source=None)
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_tasks = n_tasks
        self.n_query = n_query

        # label -> list of dataset indices with that label, in dataset order.
        self.items_per_label = {}
        for item, label in enumerate(dataset.get_labels()):
            self.items_per_label.setdefault(label, []).append(item)

    def __len__(self):
        return self.n_tasks

    def __iter__(self):
        # random.sample needs a sequence: passing dict keys is deprecated since
        # Python 3.9 and raises TypeError from 3.11 on.
        labels = list(self.items_per_label)
        n_per_label = self.k_shot + self.n_query
        for _ in range(self.n_tasks):
            # Labels are drawn first, then items per label in that order; indices
            # stay plain Python ints (no float tensor round-trip).
            yield [
                item
                for label in random.sample(labels, self.n_way)
                for item in random.sample(self.items_per_label[label], n_per_label)
            ]

    def collate_fn(self, input_data):
        """Split one episode into support and query tensors with episode-local labels.

        Args:
            input_data: list of ``(image_tensor, class_id)`` pairs in the order produced
                by ``__iter__`` — ``n_way`` contiguous groups of ``k_shot + n_query``.

        Returns:
            ``(support_images, support_labels, query_images, query_labels, true_class_ids)``
            where labels are in ``range(n_way)`` and ``true_class_ids[i]`` is the
            dataset class id mapped to episode label ``i``.
        """
        # Class ids in order of first appearance (the order the sampler drew them),
        # so that episode label i is the i-th group of the batch.
        true_class_ids = list(dict.fromkeys(x[1] for x in input_data))

        all_images = torch.stack([x[0] for x in input_data])
        all_images = all_images.reshape(
            (self.n_way, self.k_shot + self.n_query, *all_images.shape[1:])
        )

        all_labels = torch.tensor(
            [true_class_ids.index(x[1]) for x in input_data]
        ).reshape((self.n_way, self.k_shot + self.n_query))

        # Within each class group the first k_shot items are support, the rest query.
        support_images = all_images[:, : self.k_shot].reshape(
            (-1, *all_images.shape[2:])
        )
        query_images = all_images[:, self.k_shot :].reshape((-1, *all_images.shape[2:]))
        support_labels = all_labels[:, : self.k_shot].flatten()
        query_labels = all_labels[:, self.k_shot :].flatten()

        return support_images, support_labels, query_images, query_labels, true_class_ids

In [9]:
# Episodic samplers: each yields the indices of one n_way x (k_shot + n_query) task.
validation_sampler = TaskSampler(validation_set, n_way, k_shot, n_query, n_val_tasks)
test_sampler = TaskSampler(test_set, n_way, k_shot, n_query, n_test_tasks)

# Classical supervised loader over the training split.
train_loader = torch.utils.data.DataLoader(training_set, shuffle=True, batch_size=batch_size)

# Episodic loaders: the batch sampler picks the indices, its collate_fn splits support/query.
validation_loader, test_loader = (
    torch.utils.data.DataLoader(
        episode_set,
        batch_sampler=episode_sampler,
        collate_fn=episode_sampler.collate_fn,
    )
    for episode_set, episode_sampler in (
        (validation_set, validation_sampler),
        (test_set, test_sampler),
    )
)

In [10]:
import torch
import torchvision


class MatchingNetwork(torch.nn.Module):
    """Matching-Network-style few-shot classifier built on a CNN feature extractor.

    The support set is embedded by ``backbone``. Each query is scored by a softmax
    over its dot products with the L2-normalised support features, and those weights
    are summed per class (via one-hot support labels) to give class probabilities.
    With ``use_full_contextual_embedding`` the support features are additionally
    contextualised by a bidirectional LSTM, and the query features by an attention
    LSTM cell that repeatedly reads from the support set.

    Args:
        backbone: module mapping ``(batch, 3, image_size, image_size)`` images to
            ``(batch, feature_size)`` features. Defaults to a pretrained ResNet-18
            whose ``fc`` layer is replaced by ``Flatten``.
        image_size: side length of the square probe image used to measure the
            backbone's output width.
        use_full_contextual_embedding: enable the LSTM contextual embeddings.
    """

    def __init__(self, backbone=None, image_size=224, use_full_contextual_embedding=True) -> None:
        super().__init__()

        self.use_full_contextual_embedding = use_full_contextual_embedding

        self.backbone = backbone

        if self.backbone is None:
            self.backbone = self.get_backbone()

        # Width of one backbone feature vector, measured with a probe image.
        self.feature_size = self.get_output_shape(self.backbone, image_size)[0]

        # Bidirectional LSTM run over the whole support set as one sequence.
        self.support_encoder = torch.nn.LSTM(
            input_size=self.feature_size,
            hidden_size=self.feature_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # Query LSTM cell; its input is [query features, attention read-out].
        self.query_encoder = torch.nn.LSTMCell(self.feature_size * 2, self.feature_size)

        # Per-episode state written by encode_support_set() and read by
        # encode_query_features() and forward().
        self.contextualized_support_features = None
        self.one_hot_support_labels = None

        self.softmax = torch.nn.Softmax(dim=1)

    def get_backbone(self):
        """Return a pretrained ResNet-18 whose final ``fc`` layer is replaced by ``Flatten``."""
        backbone = torchvision.models.resnet18(pretrained=True)
        backbone.fc = torch.nn.Flatten()
        return backbone

    def encode_support_set(self, support_images, support_labels):
        """Embed the support set and cache its features and one-hot labels on ``self``.

        Args:
            support_images: batch of support images accepted by ``backbone``.
            support_labels: integer episode labels, one per support image.
        """
        support_features = self.backbone(support_images)

        if self.use_full_contextual_embedding:
            # The support set is fed as a single sequence (batch of 1). The bidirectional
            # output concatenates forward and backward states along the last dim; both
            # halves are added onto the raw features as residuals.
            hidden_state = self.support_encoder(support_features.unsqueeze(0))[0].squeeze(0)
            self.contextualized_support_features = support_features + hidden_state[:, : self.feature_size] + hidden_state[:, self.feature_size :]

        else:
            self.contextualized_support_features = support_features

        # num_classes is inferred from the largest label present.
        self.one_hot_support_labels = torch.nn.functional.one_hot(support_labels).float()

    def encode_query_features(self, query_set):
        """Embed query images, optionally contextualised against the cached support set.

        Must be called after ``encode_support_set``.

        Returns:
            ``(n_query, feature_size)`` query features.
        """
        query_features = self.backbone(query_set)

        if not self.use_full_contextual_embedding:
            return query_features

        hidden_state = query_features
        cell_state = torch.zeros_like(query_features)

        # One attention/LSTM step per support example.
        for _ in range(len(self.contextualized_support_features)):
            attention = self.softmax(
                hidden_state.mm(self.contextualized_support_features.T)
            )
            read_out = attention.mm(self.contextualized_support_features)
            lstm_input = torch.cat((query_features, read_out), 1)

            hidden_state, cell_state = self.query_encoder(
                lstm_input, (hidden_state, cell_state)
            )
            # Residual connection back to the raw query features.
            hidden_state = hidden_state + query_features

        return hidden_state

    def get_output_shape(self, model, image_size):
        """Return ``model``'s output shape (without the batch dim) for one RGB probe image.

        The probe is run in eval mode and without autograd, so it neither updates
        BatchNorm running statistics with random noise nor builds a graph. Every
        submodule's train/eval flag is restored afterwards. The probe is created on
        the device of ``model``'s parameters (CPU if it has none).
        """
        saved_modes = [(module, module.training) for module in model.modules()]
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        probe = torch.randn(1, 3, image_size, image_size, device=device)
        model.eval()
        try:
            with torch.no_grad():
                out = model(probe)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return out.shape[1:]

    def forward(self, support_images, support_labels, query_images):
        """Classify ``query_images`` against the given support set.

        Returns:
            ``(n_query, n_classes)`` log class probabilities, where ``n_classes`` is
            the width of the one-hot support labels. A ``1e-4`` floor keeps the log
            finite. NOTE(review): these are log-probabilities; NLLLoss is their exact
            match, whereas CrossEntropyLoss re-applies log_softmax — confirm intent.
        """
        self.encode_support_set(support_images, support_labels)

        contextualized_query_features = self.encode_query_features(
            query_images
        )

        # Attention over support examples: softmax of dot products with the
        # L2-normalised support features.
        similarity_matrix = self.softmax(
            contextualized_query_features.mm(
                torch.nn.functional.normalize(self.contextualized_support_features).T
            )
        )

        # Sum attention weights per class, then take the log.
        log_probabilities = (
            similarity_matrix.mm(self.one_hot_support_labels) + 1e-4
        ).log()

        return log_probabilities

In [11]:
def evaluate_model(model, criterion, data_loader):
    """Evaluate a few-shot model over every episode yielded by ``data_loader``.

    Each batch is expected to be the 5-tuple produced by ``TaskSampler.collate_fn``:
    ``(support_images, support_labels, query_images, query_labels, true_class_ids)``;
    the last element is ignored.

    Args:
        model: module called as ``model(support_images, support_labels, query_images)``
            that returns one row of class scores per query image.
        criterion: loss applied to ``(scores, query_labels)``.
        data_loader: iterable of episodes.

    Returns:
        ``(mean loss per episode, accuracy over all query images)``. Both are NaN
        when the loader yields no episodes.
    """
    model.eval()
    # Run on the device the model already lives on. Fall back to the
    # CUDA-if-available choice only for a parameter-less model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    running_loss = 0.0
    running_correct = 0
    total = 0
    n_episodes = 0

    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels, _ in tqdm.tqdm(data_loader):
            support_images = support_images.to(device)
            support_labels = support_labels.to(device)
            query_images = query_images.to(device)
            query_labels = query_labels.to(device)

            scores = model(support_images, support_labels, query_images)
            loss = criterion(scores, query_labels)

            running_loss += loss.item()
            total += query_labels.shape[0]
            running_correct += (scores.argmax(dim=1) == query_labels).sum().item()
            n_episodes += 1

    mean_loss = running_loss / n_episodes if n_episodes else float("nan")
    accuracy = running_correct / total if total else float("nan")

    print(f'Loss: {mean_loss}, Accuracy: {accuracy}')

    return mean_loss, accuracy


In [12]:
def train(backbone, fc_layer, matching_network, optimizer, criterion, train_loader, val_loader, num_epochs=100, label_offset=1):
    """Supervised pre-training of ``backbone`` + ``fc_layer``, with few-shot validation each epoch.

    After every training epoch, ``matching_network`` (which is expected to share
    ``backbone``) is evaluated on ``val_loader`` via ``evaluate_model``.

    Args:
        backbone: feature extractor being trained.
        fc_layer: classification head applied to the backbone's output.
        matching_network: few-shot model evaluated after each epoch.
        optimizer: optimizer over the trainable parameters.
        criterion: loss applied to ``(fc_layer output, class indices)``.
        train_loader: yields ``(images, labels)`` batches.
        val_loader: episodic loader passed to ``evaluate_model``.
        num_epochs: number of training epochs.
        label_offset: subtracted from dataset labels to obtain class indices.
            NOTE(review): the default of 1 assumes the dataset's labels start at 1;
            if the installed torchvision already returns 0-based labels, pass
            ``label_offset=0`` — confirm.

    Returns:
        ``(train_loss_history, train_acc_history, val_loss_history, val_acc_history)``,
        one entry per epoch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    backbone.to(device)
    fc_layer.to(device)
    matching_network.to(device)

    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []

    for epoch in range(num_epochs):

        backbone.train()
        fc_layer.train()

        running_loss = 0.0
        running_correct = 0
        total = 0
        n_batches = 0

        start_time = time.perf_counter()

        for images, labels in tqdm.tqdm(train_loader):
            images = images.to(device)
            targets = (labels - label_offset).to(device)

            optimizer.zero_grad()

            output = fc_layer(backbone(images))
            loss = criterion(output, targets)

            running_loss += loss.item()
            total += targets.shape[0]
            running_correct += (output.argmax(dim=1) == targets).sum().item()
            n_batches += 1

            loss.backward()
            optimizer.step()

        end_time = time.perf_counter()

        epoch_loss = running_loss / n_batches if n_batches else float("nan")
        epoch_acc = running_correct / total if total else float("nan")

        print(f'Epoch: {epoch + 1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}, Time: {(end_time - start_time):.4f}s')

        train_loss_history.append(epoch_loss)
        train_acc_history.append(epoch_acc)

        # Validate the few-shot model passed in, not a notebook-level global.
        backbone.eval()
        val_loss, val_acc = evaluate_model(matching_network, criterion, val_loader)

        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)

    return train_loss_history, train_acc_history, val_loss_history, val_acc_history

In [13]:
# Pretrained ResNet-18 with its classification head replaced by Flatten, so it
# outputs one feature vector per image.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=` — confirm against the installed version before switching.
backbone = torchvision.models.resnet18(pretrained=True)
backbone.fc = torch.nn.Flatten()

# Few-shot classifier sharing the same backbone (no full contextual embedding).
model = MatchingNetwork(image_size=image_size, use_full_contextual_embedding=False, backbone=backbone)

# Linear head over the 102 Flowers102 classes, used only for supervised pre-training of
# the shared backbone. Its input width is the backbone's measured output size rather
# than a hard-coded 512.
fully_connected_layer = torch.nn.Linear(in_features=model.feature_size, out_features=102)

criterion = torch.nn.CrossEntropyLoss()
# Only the backbone and the linear head are optimised.
optimizer = optim.SGD(list(backbone.parameters()) + list(fully_connected_layer.parameters()), lr=0.001, momentum=0.9)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [14]:
# Supervised pre-training with few-shot validation after each epoch.
# history = (train_loss, train_acc, val_loss, val_acc) per-epoch lists.
history = train(
    backbone,
    fully_connected_layer,
    model,
    optimizer,
    criterion,
    train_loader=train_loader,
    val_loader=validation_loader,
    num_epochs=n_epochs,
)

100%|██████████| 16/16 [00:15<00:00,  1.06it/s]


Epoch: 1, Loss: 4.764286935329437, Accuracy: 0.022549019607843137, Time: 15.1728s


100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


Loss: 0.3531954649090767, Accuracy: 0.9236


100%|██████████| 16/16 [00:09<00:00,  1.76it/s]


Epoch: 2, Loss: 4.381758660078049, Accuracy: 0.05392156862745098, Time: 9.1255s


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]


Loss: 0.35967999875545503, Accuracy: 0.9188


100%|██████████| 16/16 [00:09<00:00,  1.71it/s]


Epoch: 3, Loss: 3.976630836725235, Accuracy: 0.23823529411764705, Time: 9.3508s


100%|██████████| 100/100 [00:36<00:00,  2.75it/s]


Loss: 0.30677311472594737, Accuracy: 0.9424


100%|██████████| 16/16 [00:09<00:00,  1.62it/s]


Epoch: 4, Loss: 3.596532866358757, Accuracy: 0.4549019607843137, Time: 9.8999s


100%|██████████| 100/100 [00:36<00:00,  2.76it/s]


Loss: 0.3192779658362269, Accuracy: 0.9228


100%|██████████| 16/16 [00:09<00:00,  1.72it/s]


Epoch: 5, Loss: 3.2229566127061844, Accuracy: 0.6480392156862745, Time: 9.2926s


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Loss: 0.29269321866333486, Accuracy: 0.9212


100%|██████████| 16/16 [00:09<00:00,  1.63it/s]


Epoch: 6, Loss: 2.8687037229537964, Accuracy: 0.765686274509804, Time: 9.8035s


100%|██████████| 100/100 [00:36<00:00,  2.74it/s]


Loss: 0.2559988336078823, Accuracy: 0.936


100%|██████████| 16/16 [00:09<00:00,  1.72it/s]


Epoch: 7, Loss: 2.5350101739168167, Accuracy: 0.8470588235294118, Time: 9.2829s


100%|██████████| 100/100 [00:36<00:00,  2.71it/s]


Loss: 0.25716664565727115, Accuracy: 0.9304


100%|██████████| 16/16 [00:09<00:00,  1.63it/s]


Epoch: 8, Loss: 2.2473425418138504, Accuracy: 0.8892156862745098, Time: 9.8070s


100%|██████████| 100/100 [00:36<00:00,  2.71it/s]


Loss: 0.22083084305748343, Accuracy: 0.944


100%|██████████| 16/16 [00:09<00:00,  1.73it/s]


Epoch: 9, Loss: 1.9767436236143112, Accuracy: 0.9215686274509803, Time: 9.2725s


100%|██████████| 100/100 [00:36<00:00,  2.75it/s]


Loss: 0.21774466471746565, Accuracy: 0.936


100%|██████████| 16/16 [00:09<00:00,  1.64it/s]


Epoch: 10, Loss: 1.7349561676383018, Accuracy: 0.957843137254902, Time: 9.7915s


100%|██████████| 100/100 [00:36<00:00,  2.74it/s]


Loss: 0.21846646750345827, Accuracy: 0.942


100%|██████████| 16/16 [00:09<00:00,  1.71it/s]


Epoch: 11, Loss: 1.5259630233049393, Accuracy: 0.9735294117647059, Time: 9.3916s


100%|██████████| 100/100 [00:36<00:00,  2.74it/s]


Loss: 0.19907762898132206, Accuracy: 0.9412


100%|██████████| 16/16 [00:09<00:00,  1.71it/s]


Epoch: 12, Loss: 1.3266568258404732, Accuracy: 0.9892156862745098, Time: 9.3895s


100%|██████████| 100/100 [00:37<00:00,  2.68it/s]


Loss: 0.1866032757330686, Accuracy: 0.9508


100%|██████████| 16/16 [00:09<00:00,  1.72it/s]


Epoch: 13, Loss: 1.1695980876684189, Accuracy: 0.9901960784313726, Time: 9.2967s


100%|██████████| 100/100 [00:36<00:00,  2.74it/s]


Loss: 0.1796869955956936, Accuracy: 0.9484


100%|██████████| 16/16 [00:09<00:00,  1.69it/s]


Epoch: 14, Loss: 1.0308128334581852, Accuracy: 0.9911764705882353, Time: 9.4523s


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Loss: 0.19408794251270592, Accuracy: 0.9468


100%|██████████| 16/16 [00:09<00:00,  1.73it/s]


Epoch: 15, Loss: 0.8915959782898426, Accuracy: 0.996078431372549, Time: 9.2815s


100%|██████████| 100/100 [00:36<00:00,  2.76it/s]


Loss: 0.1713088256213814, Accuracy: 0.9516


100%|██████████| 16/16 [00:09<00:00,  1.73it/s]


Epoch: 16, Loss: 0.785067830234766, Accuracy: 0.9990196078431373, Time: 9.2729s


100%|██████████| 100/100 [00:36<00:00,  2.78it/s]


Loss: 0.17697429245337845, Accuracy: 0.9496


100%|██████████| 16/16 [00:09<00:00,  1.63it/s]


Epoch: 17, Loss: 0.6987030543386936, Accuracy: 0.9990196078431373, Time: 9.8207s


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Loss: 0.1670414273161441, Accuracy: 0.9544


100%|██████████| 16/16 [00:09<00:00,  1.71it/s]


Epoch: 18, Loss: 0.6111920587718487, Accuracy: 0.9990196078431373, Time: 9.3647s


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]


Loss: 0.1807084836624563, Accuracy: 0.9492


100%|██████████| 16/16 [00:09<00:00,  1.62it/s]


Epoch: 19, Loss: 0.5428508371114731, Accuracy: 0.9980392156862745, Time: 9.8745s


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]


Loss: 0.14884366361424328, Accuracy: 0.956


100%|██████████| 16/16 [00:09<00:00,  1.68it/s]


Epoch: 20, Loss: 0.48283145017921925, Accuracy: 0.9990196078431373, Time: 9.5054s


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]

Loss: 0.15529732537455856, Accuracy: 0.9516





In [15]:
# Report few-shot loss/accuracy on the validation episodes, then on the test episodes.
for split_prefix, split_loader in (('Validation: ', validation_loader), ('Test: ', test_loader)):
    print(split_prefix, end='')
    evaluate_model(model, criterion, split_loader)

Validation: 

100%|██████████| 100/100 [00:36<00:00,  2.71it/s]


Loss: 0.18227610478177667, Accuracy: 0.9412
Test: 

100%|██████████| 1000/1000 [06:06<00:00,  2.73it/s]

Loss: 0.17386747113103046, Accuracy: 0.94944





In [16]:
# Persist the few-shot model's parameters. The shared backbone is included because it is
# a registered submodule; the supervised linear head is not part of `model`.
torch.save(obj=model.state_dict(), f="matching_network_classical.pt")