In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from tqdm import tqdm


In [3]:
class prototype(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(prototype, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        z_support = self.backbone.forward(support_images)
        z_query = self.backbone.forward(query_images)

        n_way = len(torch.unique(support_labels))
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        dists = torch.cdist(z_query, z_proto)
        scores = -dists
        return scores


convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()
print(convolutional_network)

model = prototype(convolutional_network).cuda()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Data

In [14]:
from matplotlib import pyplot as plt
from PIL import Image
import torch
import torchvision.transforms as transforms
import os

image_size = 256

def get_img(path):
    image = Image.open(path)
    totensor = transforms.ToTensor()
    transform=transforms.Compose(
        [
            # Omniglot images have 1 channel, but our model will expect 3-channel images
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
        ]
    )
    transformed = transform(image)
    # transformed.show()
    return totensor(transformed)[:3]

def get_imgs(path):
    tensor = []
    imgs = os.listdir(path)
    for img in imgs:
        img_path = os.path.join(path,img)
        tensor.append(get_img(img_path))
    return torch.stack(tensor)

def train_test(tensor,split=0.8):
    idx = int(tensor.shape[0] * split)
    return tensor[:idx],tensor[idx:]

def construct(lsts):
    imgs = torch.concatenate(lsts)
    labels = []
    cnt = 0 
    for val in lsts:
        labels += [cnt for i in range(len(val))]
        cnt += 1
    labels = torch.Tensor(labels)
    return imgs,labels
    


BASE = r"C:\Users\victo\Desktop\Files\Tech\Code\Python\neurips\fewshot\mpdata"
bead = get_imgs(os.path.join(BASE,"Bead")).cuda()
fiber = get_imgs(os.path.join(BASE,"Fiber")).cuda()
fragment = get_imgs(os.path.join(BASE,"Fragment")).cuda()
negative = get_imgs(r"C:\Users\victo\Desktop\Files\Tech\Code\Python\neurips\negative").cuda()

bead_train,bead_test = train_test(bead)
fiber_train,fiber_test = train_test(fiber)
fragment_train,fragment_test = train_test(fragment)
negative_train,negative_test = train_test(negative)

labels = []

test_images,test_labels = construct([negative_test,bead_test,fiber_test,fragment_test])
train_images,train_labels = construct([negative_train,bead_train,fiber_train,fragment_train])

scores = model(
    train_images.cuda(),
    train_labels.cuda(),
    test_images.cuda()
)

_, pred_labels = torch.max(scores.data, 1)

print("Ground Truth / Predicted")
total = 0
for i in range(len(pred_labels)):
    if pred_labels[i] == test_labels[i]:
        total += 1
    # print(pred_labels[i],test_labels[i]
print(total/len(pred_labels))

Ground Truth / Predicted
0.918918918918919


In [9]:
negative = get_imgs(r"C:\Users\victo\Desktop\Files\Tech\Code\Python\neurips\negative").cuda()

In [None]:
example_scores = model(
    example_support_images.cuda(),
    example_support_labels.cuda(),
    example_query_images.cuda(),
).detach()