In [1]:
import torch
from astra.torch.models import ResNetClassifier,ResNet18_Weights
# from astra.torch.data import load_cifar_10
import torch
import torch.nn as nn
from glob import glob
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
from torch.utils.data import TensorDataset, DataLoader
from astra.torch.models import ResNetClassifier,ResNet18_Weights
from astra.torch.utils import train_fn

import torchvision.models as models
from astra.torch.metrics import accuracy_score, f1_score, precision_score, recall_score,classification_report
import torch
import torchvision
from PIL import Image
import torch.nn as nn
import os
from copy import deepcopy
import os
import numpy as np
import torch
from torchvision import datasets, transforms
import warnings

In [2]:
import os
import warnings
import torchvision.datasets
from PIL import Image  # Import the Image module from the PIL library

def prerequisite(f):
    if "TORCH_HOME" not in os.environ:
        os.environ["TORCH_HOME"] = os.path.expanduser("~/.cache/torch")
        warnings.warn(f"TORCH_HOME not set, setting it to {os.environ['TORCH_HOME']}")
    return f

prerequisite(None)
class CIFAR10Instance(torchvision.datasets.CIFAR10):
    """CIFAR10Instance Dataset."""
    def __init__(self, root=f"{os.environ['TORCH_HOME']}/data", train=True, transform=None, target_transform=None, download=True):
        super(CIFAR10Instance, self).__init__(root=root,
                                              train=train,
                                              transform=transform,
                                              target_transform=target_transform)

    def __getitem__(self, index):
        # Your implementation of __getitem__ method
        image, target = self.data[index], self.targets[index]

        # Convert numpy array to PIL Image
        image = Image.fromarray(image)

        if self.transform is not None:
            img = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index  



In [3]:
import torchvision.transforms as tfs
transform_train = tfs.Compose([
    tfs.Resize(256),
    tfs.RandomResizedCrop(size=224, scale=(0.2, 1.)),
    tfs.ColorJitter(0.4, 0.4, 0.4, 0.4),
    tfs.RandomGrayscale(p=0.2),
    tfs.RandomHorizontalFlip(),
    tfs.ToTensor(),
    tfs.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [4]:
trainset=CIFAR10Instance(root=f"{os.environ['TORCH_HOME']}/data", train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)




In [5]:
import math
import torch.nn as nn

__all__ = [ 'AlexNet', 'alexnet']
 
# (number of filters, kernel size, stride, pad)
CFG = {
    'big': [(96, 11, 4, 2), 'M', (256, 5, 1, 2), 'M', (384, 3, 1, 1), (384, 3, 1, 1), (256, 3, 1, 1), 'M'],
    'small': [(64, 11, 4, 2), 'M', (192, 5, 1, 2), 'M', (384, 3, 1, 1), (256, 3, 1, 1), (256, 3, 1, 1), 'M']
}

class AlexNet(nn.Module):
    def __init__(self, features, num_classes, init=True):
        super(AlexNet, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(nn.Dropout(0.5),
                            nn.Linear(256 * 6 * 6, 4096),
                            nn.ReLU(inplace=True),
                            nn.Dropout(0.5),
                            nn.Linear(4096, 4096),
                            nn.ReLU(inplace=True))
        self.headcount = len(num_classes)
        self.return_features = False
        if len(num_classes) == 1:
            self.top_layer = nn.Linear(4096, num_classes[0])
        else:
            for a,i in enumerate(num_classes):
                setattr(self, "top_layer%d" % a, nn.Linear(4096, i))
            self.top_layer = None  # this way headcount can act as switch.
        if init:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        if self.return_features: # switch only used for CIFAR-experiments
            return x
        if self.headcount == 1:
            if self.top_layer: # this way headcount can act as switch.
                x = self.top_layer(x)
            return x
        else:
            outp = []
            for i in range(self.headcount):
                outp.append(getattr(self, "top_layer%d" % i)(x))
            return outp

    def _initialize_weights(self):
        for y, m in enumerate(self.modules()):
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                for i in range(m.out_channels):
                    m.weight.data[i].normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def make_layers_features(cfg, input_dim, bn):
    layers = []
    in_channels = input_dim
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=3, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v[0], kernel_size=v[1], stride=v[2], padding=v[3])#,bias=False)
            if bn:
                layers += [conv2d, nn.BatchNorm2d(v[0]), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v[0]
    return nn.Sequential(*layers)


def alexnet(bn=True, num_classes=[1000], init=True, size='big'):
    dim = 3
    model = AlexNet(make_layers_features(CFG[size], dim, bn=bn), num_classes, init)
    return model

In [6]:
# dataset = load_cifar_10()
# n_train=10
# n_test=20000
# y=dataset.targets
# x=dataset.data
# classes=dataset.classes
# class_1_idx=classes.index('frog')
# class_1_mask=y==class_1_idx
# y=class_1_mask.byte()
# dataset = load_cifar_10()
# idx=torch.randperm(len(y))
# train_data=x[idx[:n_train]]
# train_targets=y[idx[:n_train]]
# test_data=x[idx[-n_test:]]
# test_targets=y[idx[-n_test:]]
# pool_data=x[idx[n_train:-n_test]]
# pool_targets=y[idx[n_train:-n_test]]
# # train_dataset=TensorDataset(train_data,train_targets)
# # test_dataset=TensorDataset(test_data,test_targets)
# # pool_dataset=TensorDataset(pool_data,pool_targets)
# # train_loader=DataLoader(train_dataset,batch_size=254,shuffle=True)
# # test_loader=DataLoader(test_dataset,batch_size=254,shuffle=False)
# # pool_loader=DataLoader(pool_dataset,batch_size=254,shuffle=True)

In [7]:
# import torchvision.transforms as transforms

# aug = transforms.Compose([
#     transforms.Resize(256),
#     transforms.RandomResizedCrop(224,scale=(0.2,1.0)),
#     transforms.RandomGrayscale(p=0.2),
#     transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
#     transforms.RandomHorizontalFlip(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225])
# ])
# augmented_train_data=aug(train_data)
# import matplotlib.pyplot as plt

# # Assuming augmented_train_data is a list of images with shape (3, 32, 32)
# fig, axes = plt.subplots(1, 10, figsize=(15, 3))

# for i in range(10):
#     # Transpose the dimensions to (32, 32, 3) for RGB image
#     image_to_display = augmented_train_data[i].permute(1, 2, 0)
#     image_to_display = (image_to_display - image_to_display.min()) / (image_to_display.max() - image_to_display.min())

#     axes[i].imshow(image_to_display)
#     axes[i].axis('off')

# plt.show()



In [8]:
# from torch.utils.data import DataLoader

# # Assuming augmented_train_data and train_targets are already defined
# train_dataset = TensorDataset(augmented_train_data, train_targets)
# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,num_workers=8)


In [9]:
import torchvision.models as models
import torch.nn as nn

# train_model = ResNetClassifier(
#     models.resnet18,None, n_classes=5, activation=nn.GELU(), dropout=0.1
# ).to("cuda")

## Initial Random Label Assignment



In [10]:
# train_model.eval()
# with torch.no_grad():
#     y_pred = train_model(augmented_train_data.to("cuda"))
#     print(y_pred.shape)

# print(y_pred)

## 6. Self Labelling with Optimal Transport


Constraint:
The unlabeled images should be divided equally into the K clusters. This is referred to as the equipartition condition in the paper.



In [11]:
# def genarate_optimal_matrix(N,K):
#     images_per_cluster=N//K
#     print(images_per_cluster)
#     matrix=np.zeros((N,K),dtype=int)
#     for j in range(K):
#         start_index=j*images_per_cluster
#         print(start_index)
#         end_index=(j+1)*images_per_cluster
#         matrix[start_index:end_index,j]=1
#     return matrix





In [12]:
# N = len(augmented_train_data)  # Number of unlabeled images
# K = 5  
# genarate_optimal_matrix(N,K)

In [13]:
model=AlexNet(make_layers_features(CFG['big'], 3, bn=True), [5], init=True)

Cost Matrix:
The cost of allocating each image to a cluster is given by the model performance when trained using these clusters as the labels. Intuitively, this means the mistake model is making when we assign an unlabeled image to some cluster. If it is high, then that means our current label assignment is not ideal and so we should change it in the optimization step.

In [14]:
# import torch
# import numpy as np

# # Assuming train_model is your PyTorch model
# # Assuming unlabeled_images is your batch of unlabeled images
# # Assuming optimal_assignment_matrix is the optimal assignment matrix generated previously
# import torch.nn.functional as F
# # Set the model to evaluation mode
# train_model.eval()

# # Move the unlabeled images to the appropriate device (GPU or CPU)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# unlabeled_images = augmented_train_data.to(device)
# optimal_assignment_matrix = genarate_optimal_matrix(N,K)
# # Calculate the model's loss for each image-cluster assignment
# cost_matrix = np.zeros((len(unlabeled_images), optimal_assignment_matrix.shape[1]))

# with torch.no_grad():
#     for i in range(len(unlabeled_images)):
#         image = unlabeled_images[i].unsqueeze(0)  # Add batch dimension
#         for j in range(optimal_assignment_matrix.shape[1]):
#             # Extract the cluster assignment for the current image
#             cluster_assignment = optimal_assignment_matrix[i, j]

#             # Convert cluster_assignment to a PyTorch tensor (1D tensor)
#             cluster_assignment_tensor = torch.tensor([cluster_assignment], dtype=torch.long, device=device)

#             # Forward pass to get model predictions
#             predictions = train_model(image)

#             # Calculate CrossEntropyLoss based on the cluster assignment
#             loss = F.cross_entropy(predictions, cluster_assignment_tensor)
            
#             # Store the loss in the cost matrix
#             cost_matrix[i, j] = loss.item()

# # Display the cost matrix
# print("Cost Matrix:")
# print(cost_matrix)


In [15]:
model=model.to('cuda')

In [16]:
def optimize_l_sk(PS):
    N,K=PS.shape
    print(PS.shape)
    # print(N,K)
    PS=PS.T # now it is K,N
    # print(PS.shape)
    # print(PS)
    r=np.ones((K,1))/K
    # print(r)
    c=np.ones((N,1))/N
    # print(c)
    # print(PS)
    PS**=10
    PS*= np.squeeze(c)
    # print(PS)
    PS=PS.T
    # print(PS)
    PS*= np.squeeze(r)
    # print(PS)
    PS=PS.T
    # print(PS)
    argmax=np.argmax(PS,axis=0)
    print(argmax)
    newLabels=torch.LongTensor(argmax)
    print(newLabels)
    PS=PS.T
    # print(PS)
    PS/= np.squeeze(r)
    # print(PS)
    PS=PS.T
    # print(PS)
    PS/= np.squeeze(c)
    # print(PS)
    sol=PS[argmax,np.arange(N)]
    # print(sol)
    np.log(sol,sol)
    print("sol",sol)
    print("nansum",np.nansum(sol))
    cost=-(1.0/10)*np.nansum(sol)/N
    print("cost",cost)
    return cost,newLabels


In [17]:
from scipy.special import logsumexp
def py_softmax(x, axis=None):
    """stable softmax"""
    return np.exp(x - logsumexp(x, axis=axis, keepdims=True))

In [18]:
def opt_sk(hc,model, selflabels_in, epoch,knn_dim,ncl):
    if hc == 1:
        PS = np.zeros((len(trainloader.dataset), ncl))
    else:
        PS_pre = np.zeros((len(trainloader.dataset), knn_dim))
    for batch_idx, (data, _, _selected) in enumerate(trainloader):
        data = data.cuda()
        if hc == 1:
            p = nn.functional.softmax(model(data), 1)
            PS[_selected, :] = p.detach().cpu().numpy()
        else:
            p = model(data)
            PS_pre[_selected, :] = p.detach().cpu().numpy()
    if hc == 1:
        cost, selflabels = optimize_l_sk(PS)
        _costs = [cost]
    else:
        _nmis = np.zeros(hc)
        _costs = np.zeros(hc)
        nh = epoch % hc  # np.random.randint(args.hc)
        print("computing head %s " % nh, end="\r", flush=True)
        tl = getattr(model, "top_layer%d" % nh)
        # do the forward pass:
        PS = (PS_pre @ tl.weight.cpu().numpy().T
                   + tl.bias.cpu().numpy())
        PS = py_softmax(PS, 1)
        c, selflabels_ = optimize_l_sk(PS)
        _costs[nh] = c
        selflabels_in[nh] = selflabels_
        selflabels = selflabels_in
    return selflabels




In [19]:
opt_sk?

[0;31mSignature:[0m [0mopt_sk[0m[0;34m([0m[0mhc[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0mselflabels_in[0m[0;34m,[0m [0mepoch[0m[0;34m,[0m [0mknn_dim[0m[0;34m,[0m [0mncl[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      /tmp/ipykernel_2487638/940078020.py
[0;31mType:[0m      function

In [20]:
hc=20
model=alexnet(bn=True, num_classes=[5], init=True, size='big')
model=model.to('cuda')
selflabels_in = [None] * hc
epoch=1
knn_dim=5
ncl=5
opt_sk(20,model,1,50,5,128)

computing head 10 

AttributeError: 'AlexNet' object has no attribute 'top_layer10'

In [21]:
import torch
import numpy as np
from torchvision.models import alexnet

numc = hc * ncl
# model = alexnet(bn=True, num_classes=[numc], init=True, size='big')
# model = model.to('cuda')
N = len(trainloader.dataset)

# Init selflabels randomly
if hc == 1:
    selflabels = torch.randint(0, ncl, (N,), dtype=torch.long).cuda()
else:
    selflabels = torch.zeros((hc, N), dtype=torch.long).cuda()
    for nh in range(hc):
        indices = torch.arange(N) % ncl
        shuffled_indices = torch.randperm(N).to('cuda')
        selflabels[nh] = indices.to('cuda')[shuffled_indices]

# Convert selflabels to long tensor
selflabels = selflabels.long().cuda()
print(selflabels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()


tensor([[4, 0, 4,  ..., 1, 1, 4],
        [4, 2, 2,  ..., 4, 1, 3],
        [0, 0, 3,  ..., 0, 2, 1],
        ...,
        [3, 1, 2,  ..., 2, 0, 0],
        [0, 1, 1,  ..., 2, 3, 1],
        [3, 2, 2,  ..., 3, 4, 2]], device='cuda:0')


In [24]:
# import matplotlib.pyplot as plt

# # Assuming trainloader is your data loader for the images
# # Assuming selflabels is the assigned selflabels for the images
# data_iter = iter(trainloader)
# images, _, selflabels = next(data_iter)

# # Assuming image is the image you want to plot
# # Assuming selflabels is the assigned selflabels for the image
# fig, ax = plt.subplots()
# ax.imshow(images[0].permute(1, 2, 0).numpy())  # Convert to NumPy array for visualization
# ax.set_title(f"Self-label: {selflabels[0].item()}")  # Adjust the title as needed
# plt.show()


In [None]:
def train(epoch,selflabels):
    