In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
import torch.optim as optim
from torchvision import transforms,utils
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import tqdm
# Report whether PyTorch can see a CUDA device in this environment.
print(torch.cuda.is_available())

True


In [9]:
class CIFAR(Dataset) :
    """In-memory dataset of episodes, each made of one query image and a support set.

    ``root_dir`` is expected to contain one sub-directory per episode. Inside an
    episode directory every file name encodes its role: the first character is a
    single-digit integer label and the third character is ``'I'`` for the query
    ("input") image; any other third character marks a support image.

    Every item is a dict with the keys ``"input_image"``, ``"input_label"``,
    ``"support_image"`` (all support images stacked along a new first axis) and
    ``"support_label"`` (float tensor with one label per support image).

    Args:
        root_dir: Directory holding the episode sub-directories.
        transform: Optional callable applied to each PIL image after loading.

    Raises:
        ValueError: If an image cannot be decoded, a file name does not follow
            the naming scheme, or an episode lacks a query or support image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Sorted so the episode order is reproducible (os.listdir order is
        # arbitrary); stray files next to the episode directories are ignored.
        self.dir = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry))
        )
        self.transform = transform
        self.data = []
        for k in tqdm.trange(len(self.dir)) :
            episode = self.dir[k]
            episode_path = os.path.join(self.root_dir, episode)
            input_image = None
            input_label = None
            support_image = []
            support_label = []
            for image in sorted(os.listdir(episode_path)) :
                label, is_query = self._parse_name(image)
                temp = self._load_image(os.path.join(episode_path, image))
                if self.transform :
                    temp = self.transform(temp)
                if is_query :
                    input_image = temp
                    input_label = label
                else :
                    support_image.append(temp)
                    support_label.append(label)
            if input_image is None or not support_image :
                raise ValueError(
                    "Episode {!r} needs one query image and at least one support image".format(episode_path)
                )
            support_label = torch.Tensor(np.asarray(support_label))
            if all(torch.is_tensor(item) for item in support_image) :
                # Transforms ending in ToTensor already yield tensors; stack directly.
                support_image = torch.stack(support_image)
            else :
                support_image = torch.from_numpy(np.stack(support_image))
            sample = {"input_image":input_image,"input_label":input_label,"support_image":support_image,"support_label":support_label}
            self.data.append(sample)

    @staticmethod
    def _parse_name(filename) :
        """Return ``(label, is_query)`` decoded from an image file name."""
        try :
            label = int(filename[0])
            is_query = filename[2] == 'I'
        except (ValueError, IndexError) as exc :
            raise ValueError("Unexpected image file name: {!r}".format(filename)) from exc
        return label, is_query

    @staticmethod
    def _load_image(path) :
        """Read ``path`` with OpenCV and return it as an RGB PIL image."""
        pixels = cv2.imread(path)
        if pixels is None :
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError("Could not read image: {!r}".format(path))
        # OpenCV decodes to BGR channel order while PIL assumes RGB.
        return Image.fromarray(cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB))

    def __len__(self) :
        """Number of episodes held in memory."""
        return len(self.data)

    def __getitem__(self, idx) :
        """Return the episode dict at ``idx`` (plain int or 0-d tensor index)."""
        if torch.is_tensor(idx) :
            idx = idx.tolist()

        return self.data[idx]

In [10]:
# Training episodes: resize to 32 px, convert to a tensor, then normalise each
# channel with mean 0.5 and std 0.5.
dataset = CIFAR(
    'data',
    transform=transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

100%|██████████| 50000/50000 [01:14<00:00, 672.50it/s]


In [11]:
# Held-out episodes, preprocessed with the same resize / tensor / normalise
# pipeline used for the training set.
testdata = CIFAR(
    'test_data',
    transform=transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

100%|██████████| 5000/5000 [00:08<00:00, 608.79it/s]


In [51]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Mini-batch iterators over the training and test episodes.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    dataset=testdata,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [52]:
class Network(nn.Module) :
    """Small convolutional encoder producing unit-length 16-d embeddings.

    Three stride-2 convolutions (each followed by batch norm and ReLU) reduce a
    3x32x32 image to 32x4x4 features, which two linear layers with dropout in
    between map to a 16-dimensional vector. The vector is L2-normalised.
    """

    def __init__(self):
        super(Network, self).__init__()
        self.main = nn.Sequential(

            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(8),
            nn.ReLU(True),

            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
        )
        self.fc1 = nn.Linear(32*4*4,64)
        self.drop = nn.Dropout(p=0.33)
        self.fc2 = nn.Linear(64,16)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Float tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 16) whose rows have unit L2 norm; an
            all-zero row stays zero instead of becoming NaN.
        """
        x = self.main(x)
        # reshape (rather than view) also accepts non-contiguous feature maps.
        x = x.reshape(-1, self.num_flat_features(x))
        x = self.fc2(self.drop(self.relu(self.fc1(x))))
        # normalize() clamps the norm with a small epsilon, so a zero vector no
        # longer triggers a 0/0 division. Accessed through nn.functional because
        # the notebook later rebinds the module-level alias F to the model.
        return nn.functional.normalize(x, p=2, dim=1)

    def num_flat_features(self,x) :
        """Return the number of elements per sample, i.e. the product of all dims but the first."""
        num_features = 1
        for extent in x.size()[1:]:
            num_features *= extent
        return num_features

In [53]:
# Scoring utilities: cosine similarity along the embedding axis (dim=2), a
# softmax across the support set (dim=1), and binary cross-entropy as the loss.
sim = nn.CosineSimilarity(dim=2)
softmax = nn.Softmax(dim=1)
crtierion = nn.BCELoss()

# Single shared encoder and its Adam optimiser. A second encoder (G) with its
# own optimiser used to sit alongside it and is currently disabled.
# NOTE: from here on the name F refers to the model and no longer to the
# torch.nn.functional alias created in the import cell.
F = Network().to(device)
optimF = optim.Adam(params=F.parameters(), lr=0.005)

In [50]:
def _score_episodes(model, batch):
    """Embed one collated batch of episodes and score every support image.

    Args:
        model: Encoder mapping images to embedding vectors.
        batch: Dict produced by the data loader with the keys "input_image",
            "input_label", "support_image" and "support_label".

    Returns:
        (scores, matches): both of shape (batch, num_support). ``scores`` is
        the softmax over the cosine similarities between the query embedding
        and each support embedding; ``matches`` is 1.0 where the support label
        equals the query label and 0.0 elsewhere.
    """
    input_image = batch["input_image"].to(device)
    support_image = batch["support_image"].to(device)
    support_label = batch["support_label"].to(device)
    batch_size, num_support = support_label.size(0), support_label.size(1)
    # Broadcast the query label across the support set (size taken from the data).
    input_label = batch["input_label"].to(device).unsqueeze(1).expand(-1, num_support)
    # Merge the (batch, support) axes so all support images go through the
    # encoder in one pass, then restore the episode structure.
    support_embed = model(support_image.flatten(0, 1)).view(batch_size, num_support, -1)
    input_embed = model(input_image).unsqueeze(1).expand(-1, num_support, -1)
    scores = softmax(sim(support_embed, input_embed))
    matches = (support_label == input_label).float()
    return scores, matches


def _accuracy(model, loader):
    """Fraction of episodes whose top-scoring support image is the first one matching the query label."""
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for batch in loader:
            scores, matches = _score_episodes(model, batch)
            correct += (scores.argmax(dim=1) == matches.argmax(dim=1)).sum().item()
            total += matches.size(0)
    return correct / total


num_epochs = 100
losses = []  # mean training loss of every completed epoch
for epoch in range(num_epochs) :
    F.train()
    batch_loss = []
    for batch in train_loader :
        optimF.zero_grad()
        output, target = _score_episodes(F, batch)
        loss = crtierion(output, target)
        loss.backward()
        optimF.step()
        # Store a plain float: keeping the tensor would retain its autograd
        # graph (and device memory) for the whole epoch.
        batch_loss.append(loss.item())
    epoch_loss = sum(batch_loss)/len(batch_loss)
    losses.append(epoch_loss)
    print("{}: Average Loss: {}".format(epoch, epoch_loss))

    F.eval()
    print("Train Accuracy: {}".format(_accuracy(F, train_loader)))
    print("Test Accuracy: {}".format(_accuracy(F, test_loader)))

KeyboardInterrupt: 

In [None]:
# Write the trained encoder's weights to result.pth in the current working directory.
path = os.path.join(os.getcwd(), "result.pth")
# Only F is saved: the second network G is disabled throughout the notebook, so
# referencing it here raised a NameError. Add a 'G' entry again if it is restored.
params = {
            'F': F.state_dict(),
        }
torch.save(params, path)