In [46]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
import torch.optim as optim
from torchvision import transforms,utils
from torch.utils.data import DataLoader,Dataset
from PIL import Image

print(torch.cuda.is_available())

True


In [47]:
class CIFAR(Dataset) :

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.dir = os.listdir(self.root_dir)
        self.transform = transform
    
    def __len__(self) :
        return len(self.dir)
    
    def __getitem__(self, idx) :
        if torch.is_tensor(idx) :
            idx = idx.tolist()
        images = os.listdir(os.path.join(self.root_dir, self.dir[idx]))
        input_image = None
        support_image = []
        input_label = None
        support_label = []
        for image in images :
            temp = cv2.imread(os.path.join(self.root_dir, self.dir[idx], image))
            temp = Image.fromarray(temp)
            if self.transform :
                temp = self.transform(temp)
            if image[2] == 'I' :
                input_image = temp
                input_label = int(image[0])
            else :
                support_image.append(temp)
                support_label.append(int(image[0]))
        input_label = torch.Tensor(input_label)
        support_label = torch.Tensor(np.asarray(support_label))
        support_image = torch.Tensor(support_image).permute(0,3,1,2)

        sample = {"input_image":input_image,"input_label":input_label,"support_image":input_image,"support_label":support_label}

        return sample

        

In [48]:
dataset = CIFAR('data',transform=transforms.Compose([transforms.Resize(32),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))

In [49]:
testdata = CIFAR('test_data',transform=transforms.Compose([transforms.Resize(32),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))

In [50]:
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
test_loader = DataLoader(testdata, batch_size=16, shuffle=True, num_workers=4)

In [51]:
class Network(nn.Module) :
    def __init__(self):
        super(Network, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.MaxPool2d(4,2,1),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
        )
        self.fc1 = nn.Linear(16*4*4,64)
        self.fc2 = nn.Linear(64,32)
    
    def forward(self, x):
        x = main(X)
        x = x.view(-1,num_flat_features(x))
        x = self.fc2(F.relu(self.fc1(x)))
        x = x / (torch.norm(x.view(x.size(0),-1), p=2, dim=1)+1e-6)
        return x
    
    def num_flat_features(self,x) :
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *=s
        return num_features

In [52]:
F = Network().to(device)
G = Network().to(device)
optimF = optim.Adam(F.parameters())
optimG = optim.Adam(G.parameters())

crtierion = nn.CosineEmbeddingLoss()
sim = nn.CosineSimilarity()

In [53]:

num_epochs = 1
losses = []
for epoch in range(num_epochs) :
    F.train()
    G.train()
    batch_loss = []
    for id, (input) in enumerate(train_loader) :
        optimF.zero_grad()
        optimG.zero_grad()
        input_image = input["input_image"].to(device)
        input_label = input["input_label"].to(device)
        support_image = input["support_image"].to(device)
        support_label = input["support_label"].to(device)
        batch_size = support_label.size(0)
        support_image = support_image.view(-1,3,32,32)
        support_embed = G(support_image).view(batch_size,-1,16)
        input_embed = F(input_image).expand(-1,support_embed.size(1),-1)
        target = torch.Tensor((support_label==input_label), dtpye=torch.int32)*2-1
        loss = crtierion(input_embed, support_embed, target)
        loss.backward()
        optimF.step()
        optimG.step()
        batch_loss.append(loss)
    print("{}: Average Loss: {}".format(epoch, batch_loss/len(batch_loss)))
    model.eval()
    correct = 0.0
    total = 0.0
    for id, (input) in enumerate(train_loader) :
        input_image = input["input_image"].to(device)
        input_label = input["input_label"].to(device)
        support_image = input["support_image"].to(device)
        support_label = input["support_label"].to(device)
        batch_size = support_label.size(0)
        support_image = support_image.view(-1,3,32,32)
        support_embed = G(support_image).view(batch_size,-1,16)
        input_embed = F(input_image).expand(-1,support_embed.size(1),-1)
        _, indices = torch.max(sim(input_embed, support_embed),dim=1)
        result = support_label[indices]
        correct += (result==input_label).sum().item()
        total += batch_size
    print("Train Accuracy: {}".format(correct/total))

    model.eval()
    correct = 0.0
    total = 0.0
    for id, (input) in enumerate(test_loader) :
        input_image = input["input_image"].to(device)
        input_label = input["input_label"].to(device)
        support_image = input["support_image"].to(device)
        support_label = input["support_label"].to(device)
        batch_size = support_label.size(0)
        support_image = support_image.view(-1,3,32,32)
        support_embed = G(support_image).view(batch_size,-1,16)
        input_embed = F(input_image).expand(-1,support_embed.size(1),-1)
        _, indices = torch.max(sim(input_embed, support_embed),dim=1)
        result = support_label[indices]
        correct += (result==input_label).sum().item()
        total += batch_size
    print("Test Accuracy: {}".format(correct/total))


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/sriyash/anaconda3/envs/tfgpu/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sriyash/anaconda3/envs/tfgpu/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/sriyash/anaconda3/envs/tfgpu/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/sriyash/anaconda3/envs/tfgpu/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/sriyash/anaconda3/envs/tfgpu/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 6 in dimension 1 at /opt/conda/conda-bld/pytorch_1579022034529/work/aten/src/TH/generic/THTensor.cpp:612
