In [1]:
import glob
import numpy as np
import sklearn
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
import os
import PIL.Image as Image
from torchvision import transforms
from torchvision.transforms import Resize, Normalize,Compose, ToTensor,RandomCrop,RandomHorizontalFlip,ColorJitter,Compose
import random
from torch.optim.lr_scheduler import StepLR
import math

from sklearn.preprocessing import normalize
from sklearn.linear_model import ElasticNet
import scipy
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA

import torch.nn.functional as F
# some parts of code is inspired by Zhang et.al.

In [2]:

torch.manual_seed(2022)
np.random.seed(2022)
random.seed(2022)

In [3]:
#Initialization
main_path='C:/Users/sanaz/Desktop/OSLOMET/Semester 2/Data mining/final project/cifar100/'
splits_path = main_path+'splits/bertinetto/'
data_path=main_path+'data/'

In [4]:
def read_classes(split):
    file_path=os.path.join(splits_path,split)
    with open(file_path) as f:
        return f.read().splitlines() 

train_classes=read_classes('train.txt')
val_classes=read_classes('val.txt')
test_classes=read_classes('test.txt')

In [5]:
print(len(train_classes))
print(len(val_classes))
print(len(test_classes))

64
16
20


In [6]:
def read_paths(classes):
    paths=[]
    labels=[]
    for i,class_name in enumerate(classes):

        for img_path in glob.glob(data_path +class_name+ '/*'):
            paths.append(img_path)
            labels.append(i)
    return paths,labels
            

# start

In [7]:
class base_dataset(Dataset):

    def __init__(self, main_path,training,size):
        
        all_train_classes=read_classes('train.txt')
        train_classes=all_train_classes[:20].copy()
        train_paths,train_labels=read_paths(train_classes)
        train_tuples=list(zip(train_paths,train_labels))
        random.shuffle(train_tuples)

        train_paths,train_labels = zip(*train_tuples)
        num=int(len(train_paths)*0.9)
        
        
        if training:
            self.paths=train_paths[:num]
            self.labels=train_labels[:num]
            self.transform=Compose([ Resize((size, size)),
                                                RandomCrop(size, padding=8),
                                                ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                                                RandomHorizontalFlip(),
                                                ToTensor(),
                                                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                                                ])
        else:
            self.paths=train_paths[num:]
       
            self.labels=train_labels[num:]
            self.transform=Compose([ Resize((size, size)),
                                                ToTensor(),
                                                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                                ])
        self.length=len(self.paths)
    def __getitem__(self, idx): 
        
        img=Image.open(self.paths[idx]).convert('RGB')
        img=self.transform(img)
        return img,self.labels[idx]
    def __len__(self):
        return self.length
        

In [8]:
class test_dataset(Dataset):

    def __init__(self, size):
        
        all_test_classes=read_classes('test.txt')
        self.test_classes=all_test_classes[:10].copy()
        self.test_paths,self.test_labels=read_paths(test_classes)
        self.size=size
        self.length=len(self.test_paths)
        self.transform = transforms.Compose([
                                               Resize((size, size)),
                                               ToTensor(),
                                               Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        
    def __getitem__(self, idx): 
        if idx == -1:
            return torch.zeros([3, self.size, self.size]), 0
        img=Image.open(self.test_paths[idx]).convert('RGB')
        img=self.transform(img)
        return img,1
    def __len__(self):
        return self.length
    
     

In [9]:

class testSampler():

    def __init__(self, lbl, batch_no ):
        self.batch_no = batch_no
        self.class_no = 5
        self.no_per_class = 35
        self.number_distract = 15

        lbl = np.array(lbl)
        self.ind = []
        for i in range(max(lbl) + 1):
            ind = np.argwhere(lbl == i).reshape(-1)
            ind = torch.from_numpy(ind)
            self.ind.append(ind)

    
    def __iter__(self):
        for b in range(self.batch_no):
            
            classes = torch.randperm(len(self.ind))
            batch_ind = []
            batch = []
            
            for c in classes[:self.class_no]:
                lb = self.ind[c]
                pos = torch.randperm(len(lb))[:self.no_per_class]
                batch_cls = lb[pos]
                ind_class = np.zeros(self.no_per_class)
                ind_class[:batch_cls.shape[0]] = 1
                if batch_cls.shape[0] != self.no_per_class:
                    batch_cls = torch.cat([batch_cls, -1*torch.ones([self.no_per_class-batch_cls.shape[0]]).long()], 0)
                batch.append(batch_cls)
                batch_ind.append(ind_class)
            batch = torch.stack(batch).t().reshape(-1)
            yield batch
            
    def __len__(self):
        return self.batch_no

In [10]:
class ResnetBlock(nn.Module):
    

    def __init__(self, in_chnls, out_chnls, stride=1):
        super(ResnetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_chnls, out_chnls, kernel_size=3, stride=stride, padding=1, bias=False)
        self.batchn1 = nn.BatchNorm2d(out_chnls)
        self.conv2 = nn.Conv2d(out_chnls, out_chnls, kernel_size=3,stride=1, padding=1, bias=False)
        self.batchn2 = nn.BatchNorm2d(out_chnls)

        self.skip = nn.Sequential()
        if stride != 1 or in_chnls != out_chnls:
            self.skip = nn.Sequential(nn.Conv2d(in_chnls, out_chnls,kernel_size=1, stride=stride, bias=False),
                                      nn.BatchNorm2d(out_chnls))
            

    def forward(self, x):
        out = F.relu(self.batchn1(self.conv1(x)))
        out = self.batchn2(self.conv2(out))
        out += self.skip(x)
        out = F.relu(out)
        return out
    
    
    
class Resnet18(nn.Module):
    def __init__(self,in_chnls,no_classes,resnet_block):
        super(Resnet18,self).__init__()
        self.in_chnls=in_chnls
        self.out_chnls=64
        self.conv1 = nn.Conv2d(in_chnls, self.out_chnls, kernel_size=3,stride=1, padding=1, bias=False)
        self.batchn = nn.BatchNorm2d(64)
        self.l1 = self.build_layer(resnet_block, 64, stride=1)
        self.l2 = self.build_layer(resnet_block, 128, stride=2)
        self.l3 = self.build_layer(resnet_block, 256, stride=2)
        self.l4 = self.build_layer(resnet_block, 512, stride=2)
        self.fc=nn.Linear(512,no_classes)
        
    def build_layer(self, block, chnls, stride):
        layers = []
        strides = [stride] + [1]
        
        for stride in strides:
            layers.append(block(self.out_chnls, chnls, stride))
            self.out_chnls = chnls
        return nn.Sequential(*layers)          
              
    
    def forward(self,x,test=False):
        
        
        x=self.conv1(x)
        x=self.batchn(x)
        x = F.relu(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = F.avg_pool2d(x, 4)
        x=x.reshape(x.shape[0],-1)
        if test: return x
        x = self.fc(x)
        return x
    
    

In [11]:
def train_backbone():
    
    #initializing hyperprameters
    
    learning_rate=0.01
    momentum=0.9
    weight_decay=1e-4
    num_classes = 20
    optim_step_size=30
    optim_gamma=0.1
    num_epochs=40  
    
    #read data
    base_train_dataset=base_dataset(main_path,training=True,size=32)
    base_train_loader = DataLoader(base_train_dataset,  batch_size=64,shuffle=True)
    base_val_dataset=base_dataset(main_path,training=False,size=32)
    base_val_loader = DataLoader(base_val_dataset,  batch_size=32,shuffle=False)
        
    model = Resnet18(3,num_classes,ResnetBlock)
    
    # Define the optimizer and loss function
    
    optim = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)   
    scheduler = StepLR(optim, optim_step_size, gamma=optim_gamma)
    criterion = nn.CrossEntropyLoss()
    
    
    best= 0.0
    for epoch in range(num_epochs):
        model.train()
        scheduler.step()
        tr_losses = []
        tr_acc = []
        
        for imgs, lbls in tqdm(base_train_loader):
            tr_pred = model(imgs)
            loss = criterion(tr_pred, lbls)
            optim.zero_grad()
            loss.backward()
            optim.step()
            tr_losses.append(loss.item())
            tr_acc.append(tr_pred.max(1)[1].eq(lbls).float().mean().item())
        val_acc = []
        model.eval()
        for imgs, lbls in base_val_loader:
            val_pred = model(imgs)
            val_pred = torch.argmax(val_pred, 1).reshape(-1)
            lbls = lbls.reshape(-1)
            val_acc += (val_pred==lbls).tolist()
        val_acc = np.mean(val_acc)
        print(f'epoch:{epoch}, training loss:{np.mean(tr_losses)},training accuracy:{np.mean(tr_acc)}, validation-accuracy:{val_acc}')
        
    return model


In [12]:
model=train_backbone()

100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [07:14<00:00,  2.57s/it]


epoch:0, training loss:2.62016238545525,training accuracy:0.19834812629152332, validation-accuracy:0.30083333333333334


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:27<00:00,  3.00s/it]


epoch:1, training loss:2.2796216645889733,training accuracy:0.2874137080280033, validation-accuracy:0.3908333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:44<00:00,  3.10s/it]


epoch:2, training loss:2.1064731637401692,training accuracy:0.342239891459956, validation-accuracy:0.4033333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [11:08<00:00,  3.95s/it]


epoch:3, training loss:1.9436858405728312,training accuracy:0.39127218934911245, validation-accuracy:0.43333333333333335


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [09:23<00:00,  3.33s/it]


epoch:4, training loss:1.8123803427938878,training accuracy:0.4402120316169671, validation-accuracy:0.4175


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:52<00:00,  3.15s/it]


epoch:5, training loss:1.6712936889490433,training accuracy:0.4780880177514793, validation-accuracy:0.5008333333333334


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [09:24<00:00,  3.34s/it]


epoch:6, training loss:1.5193449791366531,training accuracy:0.53125, validation-accuracy:0.5983333333333334


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:26<00:00,  3.00s/it]


epoch:7, training loss:1.3805347829175418,training accuracy:0.5616987178311545, validation-accuracy:0.57


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:31<00:00,  3.03s/it]


epoch:8, training loss:1.286043105746162,training accuracy:0.5954450196063025, validation-accuracy:0.6383333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:56<00:00,  3.17s/it]


epoch:9, training loss:1.1891794356368703,training accuracy:0.6236748028789046, validation-accuracy:0.6133333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:36<00:00,  3.06s/it]


epoch:10, training loss:1.1225499803497947,training accuracy:0.6468503450500894, validation-accuracy:0.7033333333333334


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [08:36<00:00,  3.06s/it]


epoch:11, training loss:1.0373569695201852,training accuracy:0.6719982740441723, validation-accuracy:0.7008333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [07:37<00:00,  2.71s/it]


epoch:12, training loss:0.9872692569473086,training accuracy:0.6910133136094675, validation-accuracy:0.7383333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:55<00:00,  2.46s/it]


epoch:13, training loss:0.9063086181702699,training accuracy:0.713233481144764, validation-accuracy:0.7383333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:47<00:00,  2.41s/it]


epoch:14, training loss:0.8507826419977041,training accuracy:0.7277490136891427, validation-accuracy:0.6791666666666667


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:47<00:00,  2.41s/it]


epoch:15, training loss:0.8121188953078005,training accuracy:0.7408468934911243, validation-accuracy:0.8016666666666666


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:38<00:00,  2.36s/it]


epoch:16, training loss:0.7686835989444214,training accuracy:0.7608173076923077, validation-accuracy:0.8075


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:32<00:00,  2.32s/it]


epoch:17, training loss:0.7306458997655902,training accuracy:0.7660564596836383, validation-accuracy:0.7691666666666667


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:34<00:00,  2.34s/it]


epoch:18, training loss:0.68245648385505,training accuracy:0.7801097141215082, validation-accuracy:0.7958333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:31<00:00,  2.32s/it]


epoch:19, training loss:0.6650139208023365,training accuracy:0.787444526627219, validation-accuracy:0.795


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:30<00:00,  2.31s/it]


epoch:20, training loss:0.6436292197577347,training accuracy:0.795888806588551, validation-accuracy:0.7975


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:31<00:00,  2.32s/it]


epoch:21, training loss:0.5993769299349135,training accuracy:0.8102502466658869, validation-accuracy:0.7958333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:31<00:00,  2.32s/it]


epoch:22, training loss:0.5762139141735946,training accuracy:0.8137327414997936, validation-accuracy:0.8533333333333334


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:34<00:00,  2.33s/it]


epoch:23, training loss:0.5417249600386479,training accuracy:0.8257211538461539, validation-accuracy:0.8108333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:32<00:00,  2.32s/it]


epoch:24, training loss:0.533044936653425,training accuracy:0.8253205129380762, validation-accuracy:0.82


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:32<00:00,  2.32s/it]


epoch:25, training loss:0.48407608332365926,training accuracy:0.8419625247723957, validation-accuracy:0.8683333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:25<00:00,  2.28s/it]


epoch:26, training loss:0.4827589916407004,training accuracy:0.8416543391329296, validation-accuracy:0.8491666666666666


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:38<00:00,  2.36s/it]


epoch:27, training loss:0.4492290725369425,training accuracy:0.8539201183431953, validation-accuracy:0.8891666666666667


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [07:06<00:00,  2.52s/it]


epoch:28, training loss:0.45702546737955874,training accuracy:0.8504684419321591, validation-accuracy:0.88


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [07:04<00:00,  2.51s/it]


epoch:29, training loss:0.30876417277303675,training accuracy:0.9019662229972478, validation-accuracy:0.9475


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [33:08<00:00, 11.77s/it]


epoch:30, training loss:0.2624849971582198,training accuracy:0.9147250987368928, validation-accuracy:0.9475


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [05:28<00:00,  1.95s/it]


epoch:31, training loss:0.2304660670856047,training accuracy:0.9256348620505023, validation-accuracy:0.9483333333333334


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:50<00:00,  2.43s/it]


epoch:32, training loss:0.22340782700911077,training accuracy:0.9286242603550295, validation-accuracy:0.9508333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:47<00:00,  2.41s/it]


epoch:33, training loss:0.22196836315492202,training accuracy:0.9259738657601486, validation-accuracy:0.9558333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:36<00:00,  2.34s/it]


epoch:34, training loss:0.2076198523919258,training accuracy:0.9341715976331361, validation-accuracy:0.9608333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:21<00:00,  2.26s/it]


epoch:35, training loss:0.19951117827871143,training accuracy:0.9364829881656804, validation-accuracy:0.9566666666666667


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:46<00:00,  2.41s/it]


epoch:36, training loss:0.1986388562787214,training accuracy:0.9350653354232833, validation-accuracy:0.96


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [07:13<00:00,  2.56s/it]


epoch:37, training loss:0.1884184495584499,training accuracy:0.9401503945948809, validation-accuracy:0.9591666666666666


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:56<00:00,  2.46s/it]


epoch:38, training loss:0.1785727602887083,training accuracy:0.9440027119139948, validation-accuracy:0.9658333333333333


100%|████████████████████████████████████████████████████████████████████████████████| 169/169 [06:20<00:00,  2.25s/it]


epoch:39, training loss:0.18238546183476081,training accuracy:0.9405510355029586, validation-accuracy:0.96


In [13]:
def get_features(model, data):
    
    if data.shape[0] > 64:
        feat = []
        count= 0
        while count <= data.shape[0]-1:
            feat.append(model(data[count:count+64], test=True).detach())
            count += 64
        feat = torch.cat(feat)
    else:
        feat = model(data, test=True).detach()
    
    return feat.numpy()

def update_sup( sup_set, X_hat, y_hat, no_sup, pseudoy):
        net = ElasticNet(l1_ratio=1.0,alpha=1.0, normalize=True, fit_intercept=True,selection='cyclic', warm_start=True)
   
        _, coefs, _ = net.path(X_hat, y_hat, l1_ratio=1.0)
        coefs = np.sum(np.abs(coefs.transpose(2, 1, 0)[::-1, no_sup:, :]), axis=2)
        sel = np.zeros(5)
        for coef in coefs:
            for i, c in enumerate(coef):
                if c == 0.0 and (i+no_sup not in sup_set) and (sel[pseudoy[i]] < 1):    
                    sup_set.append(i+no_sup)
                    sel[pseudoy[i]] += 1
            if np.sum(sel >= 1) == 5:
                break
        return sup_set

In [14]:

def get_acc( sup_X, sup_y, query_X, unlabel_X, query_y):

    no_sup =  len(sup_X)
    no_unlabel = unlabel_X.shape[0]
    
    
    sup_unlabel_feat = np.concatenate([sup_X, unlabel_X])
    pca=PCA(n_components=5)  
    X = pca.fit_transform(sup_unlabel_feat)
    
    H = np.dot(np.dot(X, np.linalg.inv(np.dot(X.T, X))), X.T)
    X_hat = np.eye(H.shape[0]) - H

    sup_set = np.arange(no_sup).tolist()
    classifier = LogisticRegression(C=10,solver='lbfgs', multi_class='auto', max_iter=1000)
    classifier.fit(sup_X, sup_y)
    accs = []
    for _ in range(no_sup + no_unlabel):
 
        pseudoy = classifier.predict(unlabel_X)
        y = np.concatenate([sup_y, pseudoy])
        Y = np.zeros((y.shape[0], 5))
        for i, n in enumerate(y):
            Y[i, n] = 1.0
        
        y_hat = np.dot(X_hat, Y)
        sup_set = update_sup(sup_set, X_hat, y_hat, no_sup, pseudoy)
        y = np.argmax(Y, axis=1)
        classifier.fit(sup_unlabel_feat[sup_set], y[sup_set])
        if len(sup_set) == len(sup_unlabel_feat):
            break
    preds = classifier.predict(query_X)
    accs.append(np.mean(preds == query_y))
    return accs
    

In [58]:
def test_classifier(model):
    
    model.eval()                            
    num_batches=20
    
    #read test data                                 
    testdataset=test_dataset(size=32)
    sampler = testSampler(testdataset.test_labels, num_batches)
    test_loader = DataLoader(testdataset, batch_sampler=sampler,shuffle=False,  pin_memory=True)
    
                            
    
                            
    for data, indicator in tqdm(test_loader):
        targets = torch.arange(5).repeat(35).long()[indicator[:175] != 0]
                            
        data = data[indicator != 0]
         
        train_targets = targets[:25].numpy()
        test_targets = targets[25:100].numpy()
                            
        train_feat = get_features(model, data[:25])
        sup_X=normalize( train_feat ) 
        sup_y=train_targets
                            
        test_feat = get_features(model, data[25:100])
        
        unlabel_feat = get_features(model, data[100:])
        query_feat = normalize(test_feat)
        unlabel_feat = normalize(unlabel_feat)                    
        acc = get_acc(sup_X,sup_y,query_feat, unlabel_feat, test_targets)
                         
       
    return acc


In [59]:
accuracy=test_classifier(model)


100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [01:07<00:00,  3.36s/it]


In [63]:
print('Query mean accuracy:',np.mean(accuracy))


Query mean accuracy: 0.72
